Skip to content

Commit 54ba07d

Browse files
authored
Merge pull request #697 from NVIDIA/am/ench
Small formatting improvements
2 parents b8e91d8 + 6bddb9b commit 54ba07d

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

src/cloudai/cli/handlers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,11 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
150150
break
151151
step, action = result
152152
env.test_run.step = step
153+
logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}")
153154
observation, reward, done, info = env.step(action)
154155
feedback = {"trial_index": step, "value": reward}
155156
agent.update_policy(feedback)
156-
logging.info(f"Step {step}: Observation: {observation}, Reward: {reward}")
157+
logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}")
157158

158159
if args.mode == "run":
159160
runner.runner.test_scenario.test_runs = original_test_runs

src/cloudai/configurator/cloudai_gym.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
104104
logging.info("Constraint check failed. Skipping step.")
105105
return [-1.0], -1.0, True, {}
106106

107-
logging.info(f"Running step {self.test_run.step} with action {action}")
108107
new_tr = copy.deepcopy(self.test_run)
109108
new_tr.output_path = self.runner.get_job_output_path(new_tr)
110109
self.runner.test_scenario.test_runs = [new_tr]

src/cloudai/workloads/nixl_bench/nixl_summary_report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def create_charts(self, cmp_groups: list[GroupedTestRuns]) -> list[bk.figure]:
7878
dfs = [self.extract_data_as_df(item.tr) for item in group.items]
7979
charts.extend(
8080
[
81-
self.create_chart(group, dfs, "Latecy", list(self.INFO_COLUMNS), ["avg_lat"], "Time (us)"),
81+
self.create_chart(group, dfs, "Latency", list(self.INFO_COLUMNS), ["avg_lat"], "Time (us)"),
8282
self.create_chart(group, dfs, "Bandwidth", list(self.INFO_COLUMNS), ["bw_gb_sec"], "Busbw (GB/s)"),
8383
]
8484
)

0 commit comments

Comments
 (0)