Skip to content

Commit c809925

Browse files
committed
update
1 parent ab8336f commit c809925

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

primus/tools/benchmark/strided_allgather_bench.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ def benchmark_strided_allgather(
163163
avg_s = total_s / iters
164164
# Each iteration moves group_size * size_bytes worth of data to each rank
165165
total_bytes = size_bytes * group_size
166-
gib_s = bytes_to_gib_per_s(total_bytes, avg_s)
166+
algobw = bytes_to_gib_per_s(total_bytes, avg_s)
167+
busbw = algobw * (group_size - 1) / group_size
167168

168169
if rank == print_rank:
169170
print(
170171
f"[RANK-{rank}][StridedAllGather][Parallel][Group-{group_id}][Ranks-{group_ranks}] size={format_size_mb(size_bytes)} "
171-
f"avg={avg_s*1e3:.2f} ms agg_bw={gib_s:.2f} GiB/s"
172+
f"avg={avg_s*1e3:.2f} ms algobw={algobw:.2f} GiB/s busbw={busbw:.2f} GiB/s"
172173
)
173174
else:
174175
for g_id in range(num_groups):
@@ -183,11 +184,12 @@ def benchmark_strided_allgather(
183184
total_s = t1 - t0
184185
avg_s = total_s / iters
185186
total_bytes = size_bytes * group_size
186-
gib_s = bytes_to_gib_per_s(total_bytes, avg_s)
187+
algobw = bytes_to_gib_per_s(total_bytes, avg_s)
188+
busbw = algobw * (group_size - 1) / group_size
187189
if rank == print_rank:
188190
print(
189191
f"[RANK-{rank}][StridedAllGather][Single][Group-{g_id}][Ranks-{group_ranks}] size={format_size_mb(size_bytes)} "
190-
f"avg={avg_s*1e3:.2f} ms agg_bw={gib_s:.2f} GiB/s"
192+
f"avg={avg_s*1e3:.2f} ms algobw={algobw:.2f} GiB/s busbw={busbw:.2f} GiB/s"
191193
)
192194
barrier()
193195
torch.cuda.synchronize()

0 commit comments

Comments
 (0)