Open
Changes from 146 commits
Commits
151 commits
e711b62
Initial plan
Copilot Feb 11, 2026
c881bc5
Replace gtest/gtest.h with framework.hpp in all unit tests
Copilot Feb 11, 2026
e227fdc
Convert mp_unit tests from gtest to framework.hpp
Copilot Feb 11, 2026
1e32e17
Address code review comments
Copilot Feb 11, 2026
eafa6fb
Add custom test framework and code coverage support
Copilot Feb 11, 2026
3d8a2e7
Add --gtest_filter support to framework
Copilot Feb 11, 2026
a10aff5
Address code review feedback
Copilot Feb 11, 2026
1818709
Fix CodeQL workflow by disabling test builds
Copilot Feb 11, 2026
5657e4a
Initial plan for fixing test build with GPU bypass
Copilot Feb 11, 2026
0eae34c
Fix test framework for building with Docker
Copilot Feb 11, 2026
4823583
Move FailHelper and SkipHelper into mscclpp::test namespace
Copilot Feb 11, 2026
403b2fb
Remove unnecessary CMake build artifacts from PR
Copilot Feb 11, 2026
305d157
Remove PerfTestResult and reuse TestResult directly
Copilot Feb 11, 2026
b1f458e
Convert test framework identifiers from snake_case to camelCase
Copilot Feb 11, 2026
6da12fa
Comprehensive plan for refactoring
Copilot Feb 11, 2026
7e4365f
Add performance test filtering and remove HTML coverage
Copilot Feb 11, 2026
b59196b
Integrate perf tests into unit_tests and add CI coverage step
Copilot Feb 11, 2026
ba0451a
Remove build2 CMake artifacts from repository
Copilot Feb 11, 2026
50f6a24
Remove test/perf/ directory completely
Copilot Feb 11, 2026
e26f8ab
Address PR review comments
Copilot Feb 11, 2026
7003fec
Simplify filter matching to use substring matching
Copilot Feb 11, 2026
30b9891
simplifying
chhwang Feb 19, 2026
b6ce0f2
simplify
chhwang Feb 19, 2026
d2efc2f
coverage update
chhwang Feb 19, 2026
4afbf78
minor
chhwang Feb 19, 2026
e40c72b
license text update
chhwang Feb 19, 2026
bed85b5
codecov upload
chhwang Feb 19, 2026
4d9acea
badge
chhwang Feb 19, 2026
b693d1b
lint issue
chhwang Feb 19, 2026
2b4adcc
fix lint
chhwang Feb 19, 2026
b64536f
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 19, 2026
dcdd3fe
update UT CI
chhwang Feb 20, 2026
caeec75
updates
chhwang Feb 20, 2026
b9609f8
add coverage flags
chhwang Feb 20, 2026
41695ba
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 20, 2026
febdbf9
WIP; need amd fix
chhwang Feb 21, 2026
c4afbe1
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 23, 2026
04ebd9b
fix coverage file path
chhwang Feb 23, 2026
54e46ba
rocm fix wip
chhwang Feb 23, 2026
6c2bc8f
coverage fix
chhwang Feb 23, 2026
d0c709e
Fix Codecov token usage in coverage upload step
chhwang Feb 23, 2026
edda25d
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 23, 2026
2f02d38
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 24, 2026
2adf4a4
use variable group
chhwang Feb 24, 2026
98b023a
rocm fixes
chhwang Feb 24, 2026
22e5efb
gdrcopy install in container
chhwang Feb 24, 2026
2f27d7d
Update coverage report to exclude additional directories in lcov command
chhwang Feb 24, 2026
d88ee8d
Refine coverage report to include only mscclpp source and include dir…
chhwang Feb 24, 2026
11e27e2
Update coverage report commands to handle errors and adjust paths
chhwang Feb 24, 2026
25f31b4
updates
chhwang Feb 24, 2026
75dfdd9
Merge branch 'main' into chhwang/fix-ib-no-atomic
chhwang Feb 24, 2026
ac4d713
updates
chhwang Feb 24, 2026
ac022c3
a few updates
chhwang Feb 25, 2026
72407af
License
chhwang Feb 25, 2026
8effd97
License
chhwang Feb 25, 2026
fd7358d
License, lint
chhwang Feb 25, 2026
67d1706
optimized recv loop
chhwang Feb 26, 2026
060982d
updates
chhwang Feb 26, 2026
6b2f819
Merge branch 'main' into chhwang/fix-ib-no-atomic
chhwang Feb 26, 2026
eb99a26
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Feb 27, 2026
8c3a436
update CI
chhwang Feb 27, 2026
f4b8574
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Mar 3, 2026
3b56b08
data direct
chhwang Mar 4, 2026
448ceb6
updates
chhwang Mar 5, 2026
7ce841b
Updates
chhwang Mar 5, 2026
bbb9c10
Update Docker image
chhwang Mar 6, 2026
60ff32c
updates
chhwang Mar 6, 2026
00583da
separate pipeline for codecov
chhwang Mar 6, 2026
c699b8a
az pipeline refactoring
chhwang Mar 7, 2026
284d913
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Mar 7, 2026
75ac8be
fix
chhwang Mar 7, 2026
e0c7ddb
fix
chhwang Mar 7, 2026
c40a233
fix
chhwang Mar 7, 2026
375bc13
fix
chhwang Mar 7, 2026
bcb392f
updates
chhwang Mar 8, 2026
ea1dd65
fix
chhwang Mar 8, 2026
d6a6fa2
simplified
chhwang Mar 8, 2026
a9cf938
fix
chhwang Mar 9, 2026
6647338
debugging
chhwang Mar 10, 2026
7a87c2c
debugging
chhwang Mar 10, 2026
cf505d7
debugging
chhwang Mar 10, 2026
757c0ec
debugging
chhwang Mar 11, 2026
e2a5be4
debugging
chhwang Mar 11, 2026
2a705f5
fix merge
chhwang Mar 11, 2026
a38bd9d
Merge branch 'main' into copilot/remove-gtest-use-custom-framework
chhwang Mar 11, 2026
e2a9692
fix merge
chhwang Mar 11, 2026
2c4bab8
fix
chhwang Mar 16, 2026
a937ce4
debugging
chhwang Mar 16, 2026
d66d7e4
debugging
chhwang Mar 17, 2026
5a65cc7
debugging
chhwang Mar 17, 2026
2297a3d
updates
chhwang Mar 18, 2026
2756221
update
chhwang Mar 18, 2026
bff76d5
Fix TearDown() handling and replace assert() in perf tests
Copilot Mar 18, 2026
6082648
fix for npkit
chhwang Mar 18, 2026
79a0149
updates
chhwang Mar 18, 2026
0200532
Merge branch 'copilot/remove-gtest-use-custom-framework' into chhwang…
chhwang Mar 18, 2026
80f554e
Merge branch 'main' into chhwang/fix-ib-no-atomic
chhwang Mar 26, 2026
67f9933
fix data direct
chhwang Apr 1, 2026
d1124fb
revert
chhwang Apr 1, 2026
144046b
revert
chhwang Apr 1, 2026
f8e94d9
disable mlx5dv_reg_dmabuf_mr
chhwang Apr 1, 2026
4cf5332
updates
chhwang Apr 1, 2026
848b89b
64-bit token reconstruction
chhwang Apr 1, 2026
ff4d825
Merge branch 'main' into chhwang/fix-ib-no-atomic
chhwang Apr 1, 2026
94d0508
prerequisites update
chhwang Apr 1, 2026
553fd3b
lint
chhwang Apr 1, 2026
53099a7
Merge branch 'main' into chhwang/fix-ib-no-atomic
chhwang Apr 2, 2026
f62633a
mlx5dv bug fixes & enhanced unit tests perf reporting
chhwang Apr 4, 2026
b04fa2d
lint
chhwang Apr 4, 2026
a4bb8fb
add debugging code
mahdiehghazim Apr 3, 2026
194a79f
add sendrecv correctness check
mahdiehghazim Apr 3, 2026
49979e5
tune #instances and remoce extra barriers
mahdiehghazim Mar 19, 2026
27fbddb
update the executor so we have message size range
mahdiehghazim Mar 17, 2026
d07a1ba
show scale in output
mahdiehghazim Mar 17, 2026
a191f16
add scripts
mahdiehghazim Mar 17, 2026
b1cc649
re-format output
mahdiehghazim Mar 17, 2026
a4118ea
update the number of instances
mahdiehghazim Mar 17, 2026
289f89d
update
Binyang2014 Mar 12, 2026
1e6d493
update
Binyang2014 Mar 9, 2026
251873c
update
Binyang2014 Mar 9, 2026
07d97f6
Unique QP per channel and env-controlled GID index
Binyang2014 Mar 9, 2026
8cecfee
debug
Binyang2014 Mar 9, 2026
ad56728
fix
Binyang2014 Mar 8, 2026
e487f83
debug
Binyang2014 Mar 6, 2026
2c3f125
add changes from ib and connection
mahdiehghazim Apr 6, 2026
1a065dd
add help scripts
mahdiehghazim Apr 6, 2026
812f6cf
fix hang on 4 ranks and make send/recv test more like nccl-test
mahdiehghazim Apr 7, 2026
3f2ade2
add barrier
mahdiehghazim Apr 7, 2026
6d8fb00
add extra signal/wait and avoid local flush
mahdiehghazim Apr 9, 2026
96defbd
add executor for testing
mahdiehghazim Apr 10, 2026
68690ec
revert dsl
mahdiehghazim Apr 10, 2026
54c2f50
merge main
Apr 10, 2026
f83a557
Add sendrecv support with double-buffer to executor_test
Apr 11, 2026
76fdd1d
WIP
Apr 11, 2026
57f7be6
WIP
Apr 11, 2026
65139d6
WIP
mahdiehghazim Apr 11, 2026
456ef7e
fix
mahdiehghazim Apr 11, 2026
36abcbe
WIP
mahdiehghazim Apr 11, 2026
a2a1b89
for 4 nodes
Binyang2014 Apr 13, 2026
1fd5ed8
update the script
mahdiehghazim Apr 13, 2026
4a17b64
update
Binyang2014 May 20, 2026
3a1e2d4
clean
Binyang2014 May 21, 2026
8a42fe2
revert
Binyang2014 May 22, 2026
7784407
WIP
Binyang2014 May 22, 2026
e600520
WIP
Binyang2014 May 22, 2026
4e09967
Merge branch 'main' into binyli/GB200
Binyang2014 May 22, 2026
3bd24e1
WIP
Binyang2014 May 22, 2026
142e794
WIP
Binyang2014 May 22, 2026
fd27fa0
Simplify executor_test: unify single/double-buffer paths via lists
Binyang2014 May 22, 2026
4463595
Merge branch 'main' into binyli/GB200
Binyang2014 May 26, 2026
bde8d45
WIP
Binyang2014 May 27, 2026
43 changes: 43 additions & 0 deletions python/mscclpp/language/collectives.py
@@ -236,3 +236,46 @@ def init_buffers(self):
             }
             rank_buffers.append(buffers)
         return rank_buffers
+
+
+class SendRecv(Collective):
+    """A SendRecv collective communication pattern.
+
+    SendRecv performs a point-to-point send/receive operation.
+    Each rank sends its input buffer to the next rank and receives data from the
+    previous rank into its output buffer.
+
+    This operation creates input and output buffers both sized by chunk_factor,
+    as each rank sends and receives the same amount of data.
+    """
+
+    def __init__(self, num_ranks, chunk_factor, inplace):
+        """Initialize a new SendRecv collective.
+
+        Args:
+            num_ranks (int): The number of ranks participating in the SendRecv.
+            chunk_factor (int): The size factor for data chunks.
+            inplace (bool): Whether the operation should be performed in-place.
+
+        Example:
+            >>> sendrecv = SendRecv(num_ranks=4, chunk_factor=1, inplace=False)
+        """
+        Collective.__init__(self, num_ranks, chunk_factor, inplace)
+        self.name = "sendrecv"
+
+    def init_buffers(self):
+        """Initialize buffers for the SendRecv operation.
+
+        Creates input and output buffers both sized by chunk_factor.
+
+        Returns:
+            list: A list of buffer dictionaries, one for each rank.
+        """
+        rank_buffers = []
+        for rank in range(self.num_ranks):
+            buffers = {
+                BufferType.input: BaseBuffer(rank, BufferType.input, 0, self.chunk_factor),
+                BufferType.output: BaseBuffer(rank, BufferType.output, 0, self.chunk_factor),
+            }
+            rank_buffers.append(buffers)
+        return rank_buffers
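The semantics described in the `SendRecv` docstring (each rank's output buffer ends up holding the previous rank's input) can be modelled in plain Python. This is only a sketch: `simulate_ring_sendrecv` is a hypothetical helper, not part of mscclpp, and it ignores chunking, `inplace`, and the grouping by `split_mask` used by the test program in this PR.

```python
def simulate_ring_sendrecv(inputs):
    """Model a ring send/recv over len(inputs) ranks.

    Rank r sends its input to rank (r + 1) % n, so rank r's output
    is a copy of the input of rank (r - 1) % n.
    """
    n = len(inputs)
    outputs = [None] * n
    for rank in range(n):
        next_rank = (rank + 1) % n
        # Copy the data, since input and output buffers are distinct.
        outputs[next_rank] = list(inputs[rank])
    return outputs


# Four ranks, two elements each; values encode the owning rank.
inputs = [[rank * 10 + i for i in range(2)] for rank in range(4)]
outputs = simulate_ring_sendrecv(inputs)
```

Because every rank sends and receives exactly one buffer of the same size, the input and output buffers can share the same `chunk_factor`, which matches what `init_buffers` allocates.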
15 changes: 10 additions & 5 deletions python/mscclpp/language/rank.py
@@ -304,11 +304,16 @@ def __init__(self, rank: int, buffer_type: BufferType, offset: int, size: int):
         self.size = offset + size

     def __getitem__(self, key):
-        if self.offset + key.stop > self.size:
-            raise RuntimeError(
-                f"Index range from {self.offset + key.start} - {self.offset + key.stop} is out of bounds for buffer {self.buffer_type}. Buffer size: {self.size}"
-            )
-        return Chunk(self.rank, self.buffer_type, self.offset + key.start, key.stop - key.start)
+        if isinstance(key, slice):
+            start = key.start if key.start is not None else 0
+            stop = key.stop if key.stop is not None else (self.size - self.offset)
+            if self.offset + stop > self.size:
+                raise RuntimeError(
+                    f"Index range from {self.offset + start} - {self.offset + stop} is out of bounds for buffer {self.buffer_type}. Buffer size: {self.size}"
+                )
+            return Chunk(self.rank, self.buffer_type, self.offset + start, stop - start)
+        else:
+            raise TypeError(f"Buffer indices must be slices, not {type(key).__name__}")


 class Buffer(BaseBuffer):
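To make the new slice handling easier to follow, here is a standalone sketch of the same normalization logic. `normalize_slice` is a hypothetical free function written for illustration; it follows the diff's convention that `size` is the absolute end of the window (`offset + length`), and, like the diff, it does not handle negative indices or slice steps.

```python
def normalize_slice(key, offset, size):
    """Return (absolute_start, length) for a slice into a buffer window.

    `offset` is where the window starts and `size` is its absolute end,
    so the window holds `size - offset` elements.
    """
    if not isinstance(key, slice):
        raise TypeError(f"Buffer indices must be slices, not {type(key).__name__}")
    # Open-ended slices default to the full window.
    start = key.start if key.start is not None else 0
    stop = key.stop if key.stop is not None else (size - offset)
    if offset + stop > size:
        raise RuntimeError(
            f"Index range from {offset + start} - {offset + stop} is out of bounds. "
            f"Buffer size: {size}"
        )
    return offset + start, stop - start


# A window of 4 elements starting at offset 2 (absolute end 6).
full = normalize_slice(slice(None), 2, 6)
part = normalize_slice(slice(1, 3), 2, 6)
```

With this behavior, `buffer[:]` selects the whole window, which is what the new `send_recv.py` test relies on when it passes `dst_buffer[:]` and `src_buffer[:]`.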
89 changes: 89 additions & 0 deletions python/mscclpp/language/tests/multi_node/send_recv.py
@@ -0,0 +1,89 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import argparse
+from mscclpp.language.channel import *
+from mscclpp.language.rank import *
+from mscclpp.language.general import *
+from mscclpp.language.program import *
+from mscclpp.language.collectives import *
+
+
+def send_recv(name, nnodes, gpus_per_node, split_mask, instances):
+    gpu_size = nnodes * gpus_per_node
+    collective = SendRecv(gpu_size, 1, False)
+    with CollectiveProgram(
+        name,
+        collective,
+        gpu_size,
+        protocol="Simple",
+        num_threads_per_block=1024,
+        use_double_scratch_buffer=False,
+        min_message_size=0,
+        max_message_size=2**64 - 1,
+        instances=instances,
+    ):
+        # Creating separate port channels for next and prev directions.
+        # When prev and next are the same peer (e.g., 2-node ring), both channels go to the same peer
+        # and get distinct tags. To ensure cross-rank tag matching (rank A's prev_channel signal
+        # arrives at rank B's next_channel wait), we create channels in opposite order for the
+        # "higher" rank so that tags cross-match:
+        # Lower rank: [next(tag0), prev(tag1)]
+        # Higher rank: [prev(tag0), next(tag1)]
+        # Then lower.prev(tag1) == higher.next(tag1) and higher.prev(tag0) == lower.next(tag0)
+        # When prev != next (3+ nodes), each channel targets a different peer so each gets tag 0
+        # and this ordering doesn't matter.
+        group_size = split_mask + 1
+        num_groups = gpu_size // group_size
+        next_channels = {}  # channel for sending to next rank
+        prev_channels = {}  # channel for receiving from prev rank
+        prev_next_ids = {}
+        for node in range(nnodes):
+            for gpu in range(gpus_per_node):
+                global_rank_id = gpu + gpus_per_node * node
+                position_in_group = global_rank_id & split_mask
+                group_id = global_rank_id // group_size
+                next_group_id = (group_id + 1) % num_groups
+                next_global_rank_id = next_group_id * group_size + position_in_group
+                prev_group_id = (group_id - 1 + num_groups) % num_groups
+                prev_global_rank_id = prev_group_id * group_size + position_in_group
+                if prev_global_rank_id == next_global_rank_id and global_rank_id > prev_global_rank_id:
+                    # Higher rank: create prev first, then next (swapped order)
+                    prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
+                    next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
+                else:
+                    # Lower rank or different peers: create next first, then prev
+                    next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
+                    prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
+                prev_next_ids[global_rank_id] = (prev_global_rank_id, next_global_rank_id)
+
+        # sync with the next rank and the previous rank in the group
+        for node in range(nnodes):
+            for gpu in range(gpus_per_node):
+                global_rank_id = gpu + gpus_per_node * node
+                prev_global_rank_id, next_global_rank_id = prev_next_ids[global_rank_id]
+                prev_channels[global_rank_id].signal(tb=0, data_sync=SyncType.none)
+                next_channels[global_rank_id].wait(tb=0, data_sync=SyncType.after)
+
+                src_rank = Rank(global_rank_id)
+                src_buffer = src_rank.get_input_buffer()
+                dst_rank = Rank(next_global_rank_id)
+                dst_buffer = dst_rank.get_output_buffer()
+
+                next_channels[global_rank_id].put_with_signal(dst_buffer[:], src_buffer[:], tb=0)
+                prev_channels[global_rank_id].wait(tb=0, data_sync=SyncType.none)
+
+        print(JSON())
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--name", type=str, help="name of the program")
+parser.add_argument("--nnodes", type=int, default=1, help="number of nodes")
+parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
+parser.add_argument("--split_mask", type=lambda x: int(x, 0), default=0x3, help="split mask (e.g. 0x3)")
+parser.add_argument("--instances", type=int, default=4, help="number of instances")
+
+args = parser.parse_args()
+
+send_recv(args.name, args.nnodes, args.gpus_per_node, args.split_mask, args.instances)
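The neighbor arithmetic and the channel creation order in `send_recv.py` can be checked without mscclpp. The sketch below copies the index computation from the diff into two hypothetical helpers, `ring_neighbors` and `channel_creation_order`. It assumes `split_mask` has the form 2^k - 1 (for example 0x3), so that `rank & split_mask` equals the position inside a group of `split_mask + 1` ranks; the diff appears to rely on the same assumption.

```python
def ring_neighbors(rank, gpu_size, split_mask):
    """Return (prev_rank, next_rank) for the ring formed across groups."""
    group_size = split_mask + 1
    num_groups = gpu_size // group_size
    position_in_group = rank & split_mask
    group_id = rank // group_size
    next_group_id = (group_id + 1) % num_groups
    prev_group_id = (group_id - 1 + num_groups) % num_groups
    next_rank = next_group_id * group_size + position_in_group
    prev_rank = prev_group_id * group_size + position_in_group
    return prev_rank, next_rank


def channel_creation_order(rank, gpu_size, split_mask):
    """Return the order in which the two channels are created for `rank`."""
    prev_rank, next_rank = ring_neighbors(rank, gpu_size, split_mask)
    if prev_rank == next_rank and rank > prev_rank:
        # Same peer in both directions and this is the higher rank: swap.
        return ["prev", "next"]
    return ["next", "prev"]


# Two groups of four ranks: prev and next are the same peer.
two_groups_low = (ring_neighbors(1, 8, 0x3), channel_creation_order(1, 8, 0x3))
two_groups_high = (ring_neighbors(5, 8, 0x3), channel_creation_order(5, 8, 0x3))
# Three groups of four ranks: prev and next differ.
three_groups = (ring_neighbors(5, 12, 0x3), channel_creation_order(5, 12, 0x3))
```

In the two-group case ranks 1 and 5 are each other's only peer, and the swapped order on rank 5 is what lets its first channel pair with rank 1's first channel, as the comment in the diff explains.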
114 changes: 87 additions & 27 deletions python/test/executor_test.py
@@ -14,6 +14,7 @@
 from mscclpp.utils import KernelBuilder, pack
 import os
 import struct
+from typing import Callable, Union

 import cupy as cp
 from mpi4py import MPI
@@ -34,13 +35,16 @@ def parse_dtype(dtype_str):
         raise ValueError(f"Unknown data type: {dtype_str}")


-def bench_time(n_iters: int, n_graph_iters: int, func):
-    # capture cuda graph for n_iters of the kernel launch
+def bench_time(n_iters: int, n_graph_iters: int, func: Union[Callable, list[Callable]]):
+    """Benchmark execution time. func can be a single callable or a list of 2 for double-buffer."""
     stream = cp.cuda.Stream(non_blocking=True)
     with stream:
         stream.begin_capture()
         for i in range(n_iters):
-            func(stream)
+            if isinstance(func, list):
+                func[i % 2](stream)
+            else:
+                func(stream)
         graph = stream.end_capture()

     # now run a warm up round
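The dispatch added to `bench_time` alternates between two callables when `func` is a list. A CUDA-free sketch of that pattern follows; `run_iterations` and the two lambdas are hypothetical stand-ins, and each callable receives the iteration index instead of a stream purely so the result can be inspected.

```python
def run_iterations(n_iters, func):
    """Call `func` once per iteration, alternating if it is a list of two."""
    results = []
    for i in range(n_iters):
        # Even iterations use entry 0, odd iterations use entry 1.
        target = func[i % 2] if isinstance(func, list) else func
        results.append(target(i))
    return results


single = run_iterations(3, lambda i: ("only", i))
double = run_iterations(4, [lambda i: ("buf0", i), lambda i: ("buf1", i)])
```

Alternating buffers this way means consecutive iterations never write into the same buffer pair, which is presumably why the PR uses it for back-to-back sendrecv launches inside one captured graph.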
@@ -61,16 +65,19 @@ def bench_time(n_iters: int, n_graph_iters: int, func):

 def bench_correctness(
     collective: str,
-    input_buf: cp.ndarray,
-    result_buf: cp.ndarray,
-    test_buf: cp.ndarray,
+    input_buf: Union[cp.ndarray, list[cp.ndarray]],
+    result_buf: Union[cp.ndarray, list[cp.ndarray]],
+    test_buf: Union[cp.ndarray, list[cp.ndarray]],
     dtype_str: str,
     rank: int,
     num_ranks: int,
     n_iters: int,
-    func,
+    func: Union[Callable, list[Callable]],
+    split_mask: int = 0,
 ):
+    """Validate correctness. For sendrecv, buffers and func are lists of 2 for double-buffer."""
     type_size = cp.dtype(parse_dtype(dtype_str)).itemsize
+    double_buf = isinstance(input_buf, list)

     fill_data_kernel_name = "fill_data_%s" % dtype_str
     if "allgather" in collective:
@@ -79,8 +86,10 @@ def bench_correctness(
         coll = "reduce_scatter"
     elif "allreduce" in collective:
         coll = "all_reduce"
+    elif "sendrecv" in collective:
+        coll = "send_recv"
     else:
-        coll = "all_to_all"
+        raise ValueError(f"Unknown collective: {collective}")
     test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str)

     file_dir = os.path.dirname(os.path.abspath(__file__))
@@ -97,11 +106,27 @@ def bench_correctness(
     with stream:
         stream.begin_capture()
         for i in range(n_iters):
-            fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i)
+            if double_buf:
+                idx = i % 2
+                cur_input = input_buf[idx]
+                cur_result = result_buf[idx]
+                cur_test = test_buf[idx]
+                cur_func = func[idx]
+            else:
+                cur_input = input_buf
+                cur_result = result_buf
+                cur_test = test_buf
+                cur_func = func
+
+            fill_data_params = (
+                pack(cur_input) + struct.pack("Q", cur_input.nbytes // type_size) + pack(rank, i, split_mask)
+            )
             fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream)
-            func(stream)
+            cur_func(stream)
             test_data_params = (
-                pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, rank, i)
+                pack(cur_result, cur_test)
+                + struct.pack("Q", cur_input.nbytes // type_size)
+                + pack(num_ranks, rank, i, split_mask)
             )
             test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream)
         graph = stream.end_capture()
@@ -147,6 +172,13 @@ def build_bufs(
     assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size)
     nelems = size // type_size

+    # Sendrecv uses double buffering: return lists of 2 buffers
+    if "sendrecv" in collective:
+        input_bufs = [GpuBuffer(nelems, dtype=dtype) for _ in range(2)]
+        result_bufs = [GpuBuffer(nelems, dtype=dtype) for _ in range(2)]
+        test_bufs = [cp.zeros(nelems, dtype=dtype) for _ in range(2)]
+        return input_bufs, result_bufs, test_bufs, nelems
+
     if "allgather" in collective:
         assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
         nelems_input = nelems if in_place else nelems // num_ranks
@@ -173,7 +205,7 @@ def build_bufs(

     test_buf = cp.zeros(nelems, dtype=dtype)

-    return input_buf, result_buf, test_buf
+    return input_buf, result_buf, test_buf, nelems


 def main(
@@ -184,6 +216,7 @@ def main(
     packet_type: PacketType = PacketType.LL16,
     n_iters: int = 10,
     n_graph_iters: int = 10,
+    split_mask: int = 0,
 ):
     mscclpp_group = CommGroup(MPI.COMM_WORLD)
     cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
@@ -195,7 +228,7 @@ def main(
     collective = execution_plan.collective

     dtype = parse_dtype(dtype_str)
-    input_buf, result_buf, test_buf = build_bufs(
+    input_buf, result_buf, test_buf, nelem = build_bufs(
         collective,
         size,
         in_place,
@@ -204,17 +237,36 @@ def main(
         mscclpp_group.nranks,
     )

-    executor_func = lambda stream: executor.execute(
-        mscclpp_group.my_rank,
-        input_buf.data.ptr,
-        result_buf.data.ptr,
-        input_buf.nbytes,
-        result_buf.nbytes,
-        dtype_to_mscclpp_dtype(dtype_str),
-        execution_plan,
-        stream.ptr,
-        packet_type,
-    )
+    sendrecv_mode = "sendrecv" in collective
+
+    if sendrecv_mode:
+        # Double-buffer: create two executor funcs, one per buffer pair
+        executor_funcs = []
+        for idx in range(2):
+            func = lambda stream, i=idx: executor.execute(
+                mscclpp_group.my_rank,
+                input_buf[i].data.ptr,
+                result_buf[i].data.ptr,
+                input_buf[i].nbytes,
+                result_buf[i].nbytes,
+                dtype_to_mscclpp_dtype(dtype),
+                execution_plan,
+                stream.ptr,
+                packet_type,
+            )
+            executor_funcs.append(func)
+    else:
+        executor_func = lambda stream: executor.execute(
+            mscclpp_group.my_rank,
+            input_buf.data.ptr,
+            result_buf.data.ptr,
+            input_buf.nbytes,
+            result_buf.nbytes,
+            dtype_to_mscclpp_dtype(dtype),
+            execution_plan,
+            stream.ptr,
+            packet_type,
+        )

     mscclpp_group.barrier()
     bench_correctness(
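The `i=idx` default argument in the double-buffer lambda matters: Python closures look up loop variables when they are called, not when they are defined. The following sketch uses hypothetical lists `late` and `bound` to show the difference, under no assumptions beyond standard Python semantics.

```python
late = []
bound = []
for idx in range(2):
    # Looks up `idx` at call time, so every lambda sees its final value.
    late.append(lambda: idx)
    # Freezes the current value of `idx` as a default argument.
    bound.append(lambda i=idx: i)

late_results = [f() for f in late]
bound_results = [f() for f in bound]
```

Without the default argument, both executor functions in the diff would index the second buffer pair, and the double-buffering would silently degrade to a single buffer.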
@@ -226,17 +278,21 @@ def main(
         mscclpp_group.my_rank,
         mscclpp_group.nranks,
         n_iters,
-        executor_func,
+        executor_funcs if sendrecv_mode else executor_func,
+        split_mask=split_mask,
     )

     mscclpp_group.barrier()
-    execution_time = bench_time(n_iters, n_graph_iters, executor_func)
+    execution_time = bench_time(n_iters, n_graph_iters, executor_funcs if sendrecv_mode else executor_func)
     if npkit_dump_dir is not None:
         npkit.dump(npkit_dump_dir)
         npkit.shutdown()
+
+    result_nbytes = result_buf[0].nbytes if sendrecv_mode else result_buf.nbytes
     print(
         f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, "
-        f"data size: {result_buf.nbytes} bytes data type: {dtype_str} "
+        f"data size: {result_nbytes} bytes data type: {dtype().dtype.name} "
+        f"bandwidth: {result_nbytes / (execution_time * 1e-6) / (1024**3):.2f} GB/s, "
         f"packet type: {packet_type}"
     )
     executor = None
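The bandwidth figure added to the output line is bytes divided by seconds, scaled by 1024**3. A small sketch of the same arithmetic is shown below; `bandwidth_per_second` is a hypothetical helper, and the inputs are made-up numbers rather than measurements from this PR.

```python
def bandwidth_per_second(nbytes, time_us):
    """Convert a byte count and a time in microseconds to units of 2**30 bytes/s."""
    seconds = time_us * 1e-6
    return nbytes / seconds / (1024**3)


# 1 GiB moved in one second (1e6 us), and 256 MiB moved in 10 ms.
one_gib = bandwidth_per_second(1024**3, 1e6)
quarter_gib = bandwidth_per_second(256 * 1024**2, 10_000)
```

Since the divisor is 1024**3, the value is strictly in GiB/s even though the printed label reads GB/s; the two differ by about 7%.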
@@ -252,6 +308,9 @@ def main(
     parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
     parser.add_argument("--n_iters", type=int, default=10)
     parser.add_argument("--n_graph_iters", type=int, default=10)
+    parser.add_argument(
+        "--split_mask", type=lambda x: int(x, 0), default=0x0, help="split mask for sendrecv (e.g. 0x3)"
+    )
     args = parser.parse_args()

     packet_type = PacketType.LL16
@@ -267,4 +326,5 @@ def main(
         packet_type,
         args.n_iters,
         args.n_graph_iters,
+        args.split_mask,
     )