Skip to content

Commit 35d1489

Browse files
authored
Merge pull request #172 from ybenvidia/constant-message-size-graph
Fix generated html graph for constant message size
2 parents 7d07e04 + acfad47 commit 35d1489

File tree

3 files changed

+29
-65
lines changed

3 files changed

+29
-65
lines changed

src/cloudai/report_generator/tool/bokeh_report_tool.py

Lines changed: 25 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def create_figure(
4848
height: int = 308,
4949
x_axis_type: str = "linear",
5050
tools: str = "pan,wheel_zoom,box_zoom,reset,save",
51+
x_range: Optional[Range1d] = None,
5152
) -> figure:
5253
"""
5354
Create a configured Bokeh figure with common settings.
@@ -61,6 +62,7 @@ def create_figure(
6162
height (int): Height of the plot.
6263
x_axis_type (str): Type of the x-axis ('linear' or 'log').
6364
tools (str): Tools to include in the plot.
65+
x_range (Range1d): Range for the x-axis, optional.
6466
6567
Returns:
6668
figure: A Bokeh figure configured with the specified parameters.
@@ -76,6 +78,10 @@ def create_figure(
7678
y_range=y_range,
7779
align="center",
7880
)
81+
82+
if x_range is not None:
83+
plot.x_range = x_range
84+
7985
return plot
8086

8187
def add_sol_line(
@@ -164,60 +170,6 @@ def add_linear_xy_line_plot(
164170

165171
self.plots.append(p)
166172

167-
def add_log_x_linear_y_single_line_plot(
168-
self,
169-
title: str,
170-
x_column: str,
171-
y_column: str,
172-
x_axis_label: str,
173-
y_axis_label: str,
174-
df: pd.DataFrame,
175-
sol: Optional[float] = None,
176-
color: str = "black",
177-
):
178-
"""
179-
Create a single line plot with a logarithmic x-axis and linear y-axis.
180-
181-
Args:
182-
title (str): Title of the plot.
183-
x_column (str): The column used for the x-axis values.
184-
y_column (str): The column used for the y-axis values.
185-
x_axis_label (str): Label for the x-axis.
186-
y_axis_label (str): Label for the y-axis.
187-
df (pd.DataFrame): DataFrame containing the data.
188-
sol (Optional[float]): Speed-of-light performance reference line.
189-
color (str): Color of the line in the plot.
190-
191-
This function sets up a Bokeh figure and plots a single line of data. It also
192-
optionally adds a reference line (SOL) if provided. The x-axis uses a logarithmic
193-
scale, and custom JavaScript is used for tick formatting to enhance readability.
194-
"""
195-
x_min, x_max = self.find_min_max(df, x_column)
196-
y_min, y_max = self.find_min_max(df, y_column, sol)
197-
198-
# Create a Bokeh figure with logarithmic x-axis
199-
p = self.create_figure(
200-
title="CloudAI " + title,
201-
x_axis_label=x_axis_label,
202-
y_axis_label=y_axis_label,
203-
x_axis_type="log",
204-
y_range=Range1d(start=0, end=y_max * 1.1),
205-
)
206-
207-
# Add main line plot
208-
p.line(x=x_column, y=y_column, source=ColumnDataSource(df), line_width=2, color=color, legend_label=y_column)
209-
210-
self.add_sol_line(p, df, x_column, y_column, sol)
211-
212-
p.legend.location = "bottom_right"
213-
214-
p.xaxis.ticker = calculate_power_of_two_ticks(x_min, x_max)
215-
p.xaxis.formatter = CustomJSTickFormatter(code=bokeh_size_unit_js_tick_formatter)
216-
p.xaxis.major_label_orientation = pi / 4
217-
218-
# Append plot to internal list for future rendering
219-
self.plots.append(p)
220-
221173
def add_log_x_linear_y_multi_line_plot(
222174
self,
223175
title: str,
@@ -246,12 +198,25 @@ def add_log_x_linear_y_multi_line_plot(
246198
_, col_max = self.find_min_max(df, y_column, sol)
247199
y_max = max(y_max, col_max)
248200

201+
x_axis_type = "log"
202+
x_range = None
203+
204+
# Check if x_min equals x_max - constant message size
205+
if x_min == x_max:
206+
# Use iteration number as x-axis
207+
df["iteration"] = range(1, len(df) + 1)
208+
x_column = "iteration"
209+
x_axis_label = "Iteration"
210+
x_axis_type = "linear"
211+
x_range = Range1d(start=1, end=len(df))
212+
249213
p = self.create_figure(
250214
title="CloudAI " + title,
251215
x_axis_label=x_axis_label,
252216
y_axis_label=y_axis_label,
253-
x_axis_type="log",
217+
x_axis_type=x_axis_type,
254218
y_range=Range1d(start=0, end=y_max * 1.1),
219+
x_range=x_range,
255220
)
256221

257222
# Adding lines for each data type specified
@@ -265,10 +230,11 @@ def add_log_x_linear_y_multi_line_plot(
265230

266231
p.legend.location = "bottom_right"
267232

268-
# Setting up custom tick formatter for log scale readability
269-
p.xaxis.ticker = calculate_power_of_two_ticks(x_min, x_max)
270-
p.xaxis.formatter = CustomJSTickFormatter(code=bokeh_size_unit_js_tick_formatter)
271-
p.xaxis.major_label_orientation = pi / 4
233+
if x_axis_type == "log":
234+
# Setting up custom tick formatter for log scale readability
235+
p.xaxis.ticker = calculate_power_of_two_ticks(x_min, x_max)
236+
p.xaxis.formatter = CustomJSTickFormatter(code=bokeh_size_unit_js_tick_formatter)
237+
p.xaxis.major_label_orientation = pi / 4
272238

273239
self.plots.append(p)
274240

src/cloudai/schema/test_template/nccl_test/report_generation_strategy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,14 @@ def _generate_bokeh_report(
120120
("Busbw (GB/s) In-place", "green", "In-place Bus Bandwidth"),
121121
]
122122
for col_name, color, title in line_plots:
123-
report_tool.add_log_x_linear_y_single_line_plot(
123+
report_tool.add_log_x_linear_y_multi_line_plot(
124124
title=f"{test_name} {title}",
125125
x_column="Size (B)",
126-
y_column=col_name,
126+
y_columns=[(col_name, color)],
127127
x_axis_label="Message Size",
128128
y_axis_label="Bandwidth (GB/s)",
129129
df=df,
130130
sol=sol,
131-
color=color,
132131
)
133132

134133
combined_columns = [("Busbw (GB/s) Out-of-place", "blue"), ("Busbw (GB/s) In-place", "green")]

src/cloudai/schema/test_template/ucc_test/report_generation_strategy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,14 @@ def _generate_plots(self, df: pd.DataFrame, directory_path: str, sol: Optional[f
9696
report_tool = BokehReportTool(directory_path)
9797
line_plots = [("Bandwidth (GB/s) avg", "black", "Average Bandwidth")]
9898
for col_name, color, title in line_plots:
99-
report_tool.add_log_x_linear_y_single_line_plot(
99+
report_tool.add_log_x_linear_y_multi_line_plot(
100100
title=title,
101101
x_column="Size (B)",
102-
y_column=col_name,
102+
y_columns=[(col_name, color)],
103103
x_axis_label="Message Size",
104104
y_axis_label="Bandwidth (GB/s)",
105105
df=df,
106106
sol=sol,
107-
color=color,
108107
)
109108

110109
combined_columns = [

0 commit comments

Comments
 (0)