Skip to content

Commit d808b7d

Browse files
authored
Merge pull request #207 from TaekyungHeo/rm-env-var
Remove Unused env_vars From Initialization Code
2 parents 648c6a4 + b9d217d commit d808b7d

21 files changed

+34
-110
lines changed

src/cloudai/_core/command_gen_strategy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class CommandGenStrategy(TestTemplateStrategy):
3131
@abstractmethod
3232
def gen_exec_command(
3333
self,
34-
env_vars: Dict[str, str],
3534
cmd_args: Dict[str, str],
3635
extra_env_vars: Dict[str, str],
3736
extra_cmd_args: str,
@@ -43,7 +42,6 @@ def gen_exec_command(
4342
Generate the execution command for a test based on the given parameters.
4443
4544
Args:
46-
env_vars (Dict[str, str]): Environment variables for the test.
4745
cmd_args (Dict[str, str]): Command-line arguments for the test.
4846
extra_env_vars (Dict[str, str]): Additional environment variables.
4947
extra_cmd_args (str): Additional command-line arguments.

src/cloudai/_core/json_gen_strategy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class JsonGenStrategy(TestTemplateStrategy):
3131
@abstractmethod
3232
def gen_json(
3333
self,
34-
env_vars: Dict[str, str],
3534
cmd_args: Dict[str, str],
3635
extra_env_vars: Dict[str, str],
3736
extra_cmd_args: str,
@@ -44,7 +43,6 @@ def gen_json(
4443
Generate the Kubernetes job specification based on the given parameters.
4544
4645
Args:
47-
env_vars (Dict[str, str]): Environment variables for the job.
4846
cmd_args (Dict[str, str]): Command-line arguments for the job.
4947
extra_env_vars (Dict[str, str]): Additional environment variables.
5048
extra_cmd_args (str): Additional command-line arguments.

src/cloudai/_core/test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def __init__(
3434
name: str,
3535
description: str,
3636
test_template: TestTemplate,
37-
env_vars: Dict[str, str],
3837
cmd_args: Dict[str, str],
3938
extra_env_vars: Dict[str, str],
4039
extra_cmd_args: str,
@@ -52,7 +51,6 @@ def __init__(
5251
name (str): Name of the test.
5352
description (str): Description of the test.
5453
test_template (TestTemplate): Test template object.
55-
env_vars (Dict[str, str]): Environment variables for the test.
5654
cmd_args (Dict[str, str]): Command-line arguments for the test.
5755
extra_env_vars (Dict[str, str]): Extra environment variables.
5856
extra_cmd_args (str): Extra command-line arguments.
@@ -67,7 +65,6 @@ def __init__(
6765
self.name = name
6866
self.description = description
6967
self.test_template = test_template
70-
self.env_vars = env_vars
7168
self.cmd_args = cmd_args
7269
self.extra_env_vars = extra_env_vars
7370
self.extra_cmd_args = extra_cmd_args
@@ -89,7 +86,6 @@ def __repr__(self) -> str:
8986
return (
9087
f"Test(name={self.name}, description={self.description}, "
9188
f"test_template={self.test_template.name}, "
92-
f"env_vars={self.env_vars}, "
9389
f"cmd_args={self.cmd_args}, "
9490
f"extra_env_vars={self.extra_env_vars}, "
9591
f"extra_cmd_args={self.extra_cmd_args}, "
@@ -118,7 +114,6 @@ def gen_exec_command(
118114
nodes = []
119115

120116
return self.test_template.gen_exec_command(
121-
self.env_vars,
122117
self.cmd_args,
123118
self.extra_env_vars,
124119
self.extra_cmd_args,
@@ -154,7 +149,6 @@ def gen_json(
154149
nodes = []
155150

156151
return self.test_template.gen_json(
157-
self.env_vars,
158152
self.cmd_args,
159153
self.extra_env_vars,
160154
self.extra_cmd_args,

src/cloudai/_core/test_parser.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def _fetch_strategy( # noqa: D417
8888
],
8989
system_type: Type[System],
9090
test_template_type: Type[TestTemplate],
91-
env_vars: Dict[str, Any],
9291
cmd_args: Dict[str, Any],
9392
) -> Optional[
9493
Union[
@@ -108,7 +107,6 @@ def _fetch_strategy( # noqa: D417
108107
The strategy interface to fetch.
109108
system_type (Type[System]): The system type.
110109
test_template_type (Type[TestTemplate]): The test template type.
111-
env_vars (Dict[str, Any]): Environment variables.
112110
cmd_args (Dict[str, Any]): Command-line arguments.
113111
114112
Returns:
@@ -119,7 +117,7 @@ def _fetch_strategy( # noqa: D417
119117
strategy_type = registry.strategies_map.get(key)
120118
if strategy_type:
121119
if issubclass(strategy_type, TestTemplateStrategy):
122-
return strategy_type(self.system, env_vars, cmd_args)
120+
return strategy_type(self.system, cmd_args)
123121
else:
124122
return strategy_type()
125123

@@ -128,13 +126,12 @@ def _fetch_strategy( # noqa: D417
128126
)
129127
return None
130128

131-
def _get_test_template(self, name: str, env_vars: Dict[str, Any], cmd_args: Dict[str, Any]) -> TestTemplate:
129+
def _get_test_template(self, name: str, cmd_args: Dict[str, Any]) -> TestTemplate:
132130
"""
133131
Dynamically retrieves the appropriate TestTemplate subclass based on the given name.
134132
135133
Args:
136134
name (str): The name of the test template.
137-
env_vars (Dict[str, Any]): Environment variables.
138135
cmd_args (Dict[str, Any]): Command-line arguments.
139136
140137
Returns:
@@ -146,32 +143,32 @@ def _get_test_template(self, name: str, env_vars: Dict[str, Any], cmd_args: Dict
146143
if not test_template_class:
147144
raise ValueError(f"Unsupported test_template name: {name}")
148145

149-
obj = test_template_class(system=self.system, name=name, env_vars=env_vars, cmd_args=cmd_args)
146+
obj = test_template_class(system=self.system, name=name, cmd_args=cmd_args)
150147
obj.install_strategy = cast(
151-
InstallStrategy, self._fetch_strategy(InstallStrategy, type(obj.system), type(obj), env_vars, cmd_args)
148+
InstallStrategy, self._fetch_strategy(InstallStrategy, type(obj.system), type(obj), cmd_args)
152149
)
153150
obj.command_gen_strategy = cast(
154151
CommandGenStrategy,
155-
self._fetch_strategy(CommandGenStrategy, type(obj.system), type(obj), env_vars, cmd_args),
152+
self._fetch_strategy(CommandGenStrategy, type(obj.system), type(obj), cmd_args),
156153
)
157154
obj.json_gen_strategy = cast(
158155
JsonGenStrategy,
159-
self._fetch_strategy(JsonGenStrategy, type(obj.system), type(obj), env_vars, cmd_args),
156+
self._fetch_strategy(JsonGenStrategy, type(obj.system), type(obj), cmd_args),
160157
)
161158
obj.job_id_retrieval_strategy = cast(
162159
JobIdRetrievalStrategy,
163-
self._fetch_strategy(JobIdRetrievalStrategy, type(obj.system), type(obj), env_vars, cmd_args),
160+
self._fetch_strategy(JobIdRetrievalStrategy, type(obj.system), type(obj), cmd_args),
164161
)
165162
obj.job_status_retrieval_strategy = cast(
166163
JobStatusRetrievalStrategy,
167-
self._fetch_strategy(JobStatusRetrievalStrategy, type(obj.system), type(obj), env_vars, cmd_args),
164+
self._fetch_strategy(JobStatusRetrievalStrategy, type(obj.system), type(obj), cmd_args),
168165
)
169166
obj.report_generation_strategy = cast(
170167
ReportGenerationStrategy,
171-
self._fetch_strategy(ReportGenerationStrategy, type(obj.system), type(obj), env_vars, cmd_args),
168+
self._fetch_strategy(ReportGenerationStrategy, type(obj.system), type(obj), cmd_args),
172169
)
173170
obj.grading_strategy = cast(
174-
GradingStrategy, self._fetch_strategy(GradingStrategy, type(obj.system), type(obj), env_vars, cmd_args)
171+
GradingStrategy, self._fetch_strategy(GradingStrategy, type(obj.system), type(obj), cmd_args)
175172
)
176173
return obj
177174

@@ -187,7 +184,6 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
187184
"""
188185
test_def = self.load_test_definition(data)
189186

190-
env_vars = {} # this field doesn't exist in Test or TestTemplate TOMLs
191187
"""
192188
There are:
193189
1. global_env_vars, used in System
@@ -198,7 +194,7 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
198194
extra_cmd_args = test_def.extra_args_str
199195

200196
test_template_name = data.get("test_template_name", "")
201-
test_template = self._get_test_template(test_template_name, env_vars, cmd_args)
197+
test_template = self._get_test_template(test_template_name, cmd_args)
202198

203199
if not test_template:
204200
test_name = data.get("name", "Unnamed Test")
@@ -214,7 +210,6 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
214210
name=test_def.name,
215211
description=data.get("description", ""),
216212
test_template=test_template,
217-
env_vars=env_vars,
218213
cmd_args=cmd_args,
219214
extra_env_vars=extra_env_vars,
220215
extra_cmd_args=extra_cmd_args,

src/cloudai/_core/test_scenario_parser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def _create_section_test_run(self, section: str, test_info: Dict[str, Any]) -> T
137137
name=original_test.name,
138138
description=original_test.description,
139139
test_template=original_test.test_template,
140-
env_vars=copy.deepcopy(original_test.env_vars),
141140
cmd_args=copy.deepcopy(original_test.cmd_args),
142141
extra_env_vars=copy.deepcopy(original_test.extra_env_vars),
143142
extra_cmd_args=original_test.extra_cmd_args,

src/cloudai/_core/test_template.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ class TestTemplate:
3838
3939
Attributes
4040
name (str): Unique name of the test template.
41-
env_vars (Dict[str, Any]): Default environment variables.
4241
cmd_args (Dict[str, Any]): Default command-line arguments.
4342
logger (logging.Logger): Logger for the test template.
4443
install_strategy (InstallStrategy): Strategy for installing test prerequisites.
@@ -56,7 +55,6 @@ def __init__(
5655
self,
5756
system: System,
5857
name: str,
59-
env_vars: Dict[str, Any],
6058
cmd_args: Dict[str, Any],
6159
) -> None:
6260
"""
@@ -65,12 +63,10 @@ def __init__(
6563
Args:
6664
system (System): System configuration for the test template.
6765
name (str): Name of the test template.
68-
env_vars (Dict[str, Any]): Environment variables.
6966
cmd_args (Dict[str, Any]): Command-line arguments.
7067
"""
7168
self.system = system
7269
self.name = name
73-
self.env_vars = env_vars
7470
self.cmd_args = cmd_args
7571
self.install_strategy: Optional[InstallStrategy] = None
7672
self.command_gen_strategy: Optional[CommandGenStrategy] = None
@@ -127,7 +123,6 @@ def uninstall(self) -> InstallStatusResult:
127123

128124
def gen_exec_command(
129125
self,
130-
env_vars: Dict[str, str],
131126
cmd_args: Dict[str, str],
132127
extra_env_vars: Dict[str, str],
133128
extra_cmd_args: str,
@@ -141,7 +136,6 @@ def gen_exec_command(
141136
This method must be implemented by subclasses.
142137
143138
Args:
144-
env_vars (Dict[str, str]): Environment variables for the test.
145139
cmd_args (Dict[str, str]): Command-line arguments for the test.
146140
extra_env_vars (Dict[str, str]): Extra environment variables.
147141
extra_cmd_args (str): Extra command-line arguments.
@@ -160,7 +154,6 @@ def gen_exec_command(
160154
"by calling the appropriate registration function for the system type."
161155
)
162156
return self.command_gen_strategy.gen_exec_command(
163-
env_vars,
164157
cmd_args,
165158
extra_env_vars,
166159
extra_cmd_args,
@@ -171,7 +164,6 @@ def gen_exec_command(
171164

172165
def gen_json(
173166
self,
174-
env_vars: Dict[str, str],
175167
cmd_args: Dict[str, str],
176168
extra_env_vars: Dict[str, str],
177169
extra_cmd_args: str,
@@ -184,7 +176,6 @@ def gen_json(
184176
Generate a JSON string representing the Kubernetes job specification for this test using this template.
185177
186178
Args:
187-
env_vars (Dict[str, str]): Environment variables for the test.
188179
cmd_args (Dict[str, str]): Command-line arguments for the test.
189180
extra_env_vars (Dict[str, str]): Extra environment variables.
190181
extra_cmd_args (str): Extra command-line arguments.
@@ -204,7 +195,6 @@ def gen_json(
204195
"by calling the appropriate registration function for the system type."
205196
)
206197
return self.json_gen_strategy.gen_json(
207-
env_vars,
208198
cmd_args,
209199
extra_env_vars,
210200
extra_cmd_args,

src/cloudai/_core/test_template_strategy.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,42 +28,23 @@ class TestTemplateStrategy:
2828
2929
Attributes
3030
system (System): The system schema object.
31-
env_vars (Dict[str, Any]): Default environment variables.
3231
cmd_args (Dict[str, Any]): Default command-line arguments.
33-
default_env_vars (Dict[str, str]): Constructed default environment variables.
3432
default_cmd_args (Dict[str, str]): Constructed default command-line arguments.
3533
"""
3634

3735
__test__ = False
3836

39-
def __init__(self, system: System, env_vars: Dict[str, Any], cmd_args: Dict[str, Any]) -> None:
37+
def __init__(self, system: System, cmd_args: Dict[str, Any]) -> None:
4038
"""
4139
Initialize a TestTemplateStrategy instance with system configuration, env variables, and command-line arguments.
4240
4341
Args:
4442
system (System): The system configuration for the test.
45-
env_vars (Dict[str, Any]): Default environment variables.
4643
cmd_args (Dict[str, Any]): Default command-line arguments.
4744
"""
4845
self.system = system
49-
self.env_vars = env_vars
5046
self.cmd_args = cmd_args
51-
self.default_env_vars = self._construct_default_env_vars()
5247
self.default_cmd_args = self._construct_default_cmd_args()
53-
self.default_env_vars.update(system.global_env_vars)
54-
55-
def _construct_default_env_vars(self) -> Dict[str, str]:
56-
"""
57-
Construct the default environment variables for the test template.
58-
59-
Returns
60-
Dict[str, str]: A dictionary containing the default environment variables.
61-
"""
62-
return {
63-
key: value["default"]
64-
for key, value in self.env_vars.items()
65-
if isinstance(value, dict) and "default" in value
66-
}
6748

6849
def _construct_default_cmd_args(self) -> Dict[str, str]:
6950
"""

src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,14 @@ class ChakraReplaySlurmCommandGenStrategy(SlurmCommandGenStrategy):
2525

2626
def gen_exec_command(
2727
self,
28-
env_vars: Dict[str, str],
2928
cmd_args: Dict[str, str],
3029
extra_env_vars: Dict[str, str],
3130
extra_cmd_args: str,
3231
output_path: Path,
3332
num_nodes: int,
3433
nodes: List[str],
3534
) -> str:
36-
final_env_vars = self._override_env_vars(self.default_env_vars, env_vars)
37-
final_env_vars = self._override_env_vars(final_env_vars, extra_env_vars)
35+
final_env_vars = self._override_env_vars(self.system.global_env_vars, extra_env_vars)
3836
final_cmd_args = self._override_cmd_args(self.default_cmd_args, cmd_args)
3937
env_vars_str = self._format_env_vars(final_env_vars)
4038

src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@
2727
class JaxToolboxSlurmCommandGenStrategy(SlurmCommandGenStrategy):
2828
"""Command generation strategy for JaxToolbox tests on Slurm systems."""
2929

30-
def __init__(self, system: SlurmSystem, env_vars: Dict[str, Any], cmd_args: Dict[str, Any]) -> None:
31-
super().__init__(system, env_vars, cmd_args)
30+
def __init__(self, system: SlurmSystem, cmd_args: Dict[str, Any]) -> None:
31+
super().__init__(system, cmd_args)
3232
self.test_name = ""
3333

3434
def gen_exec_command(
3535
self,
36-
env_vars: Dict[str, str],
3736
cmd_args: Dict[str, str],
3837
extra_env_vars: Dict[str, str],
3938
extra_cmd_args: str,
@@ -50,7 +49,6 @@ def gen_exec_command(
5049
the appropriate strategy for handling thresholds.
5150
5251
Args:
53-
env_vars (Dict[str, str]): Environment variables for the job.
5452
cmd_args (Dict[str, str]): Command-line arguments for the job.
5553
extra_env_vars (Dict[str, str]): Additional environment variables.
5654
extra_cmd_args (str): Additional command arguments.
@@ -63,8 +61,7 @@ def gen_exec_command(
6361
"""
6462
self.test_name = self._extract_test_name(cmd_args)
6563

66-
final_env_vars = self._override_env_vars(self.default_env_vars, env_vars)
67-
final_env_vars = self._override_env_vars(final_env_vars, extra_env_vars)
64+
final_env_vars = self._override_env_vars(self.system.global_env_vars, extra_env_vars)
6865
cmd_args["output_path"] = str(output_path)
6966

7067
combine_threshold_bytes = int(final_env_vars["COMBINE_THRESHOLD"])

src/cloudai/schema/test_template/nccl_test/kubernetes_json_gen_strategy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class NcclTestKubernetesJsonGenStrategy(JsonGenStrategy):
2525

2626
def gen_json(
2727
self,
28-
env_vars: Dict[str, str],
2928
cmd_args: Dict[str, str],
3029
extra_env_vars: Dict[str, str],
3130
extra_cmd_args: str,
@@ -34,8 +33,7 @@ def gen_json(
3433
num_nodes: int,
3534
nodes: List[str],
3635
) -> Dict[Any, Any]:
37-
final_env_vars = self._override_env_vars(self.default_env_vars, env_vars)
38-
final_env_vars = self._override_env_vars(final_env_vars, extra_env_vars)
36+
final_env_vars = self._override_env_vars(self.system.global_env_vars, extra_env_vars)
3937
final_cmd_args = self._override_cmd_args(self.default_cmd_args, cmd_args)
4038
final_num_nodes = self._determine_num_nodes(num_nodes, nodes)
4139
job_spec = self._create_job_spec(

0 commit comments

Comments
 (0)