13 changes: 11 additions & 2 deletions collector/collect.py
@@ -81,8 +81,13 @@ def worker(queue, device_id: int, func, progress_value, lock, error_queue=None,
setup_signal_handlers(device_id, error_queue)

# Setup device
device = torch.device(f"cuda:{device_id}")
torch.cuda.set_device(device_id)

if torch.cuda.is_available():
device = torch.device(f"cuda:{device_id}")
torch.cuda.set_device(device)
elif torch.xpu.is_available():
device = torch.device(f"xpu:{device_id}")
torch.xpu.set_device(device)
worker_logger.info(f"Worker {device_id} initialized for {module_name}")

# Process tasks
@@ -549,6 +554,7 @@ def collect_vllm(num_processes: int, ops: list[str] | None = None):
"get_func": "get_moe_test_cases",
"run_func": "run_moe_torch",
},
# TODO sihan: recheck whether MLA is supported on XPU
{
"name": "vllm",
"type": "mla_context",
@@ -777,6 +783,9 @@ def main():
setup_logging(debug=args.debug)

num_processes = torch.cuda.device_count()
if num_processes == 0:
if torch.xpu.is_available():
num_processes = torch.xpu.device_count()
logger.info(f"Starting collection with {num_processes} GPU processes")

# Set environment variables for worker processes
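Note: the worker setup and main() above now repeat the same CUDA-vs-XPU branching. A minimal sketch of how that could be factored into shared helpers follows, assuming a PyTorch build that exposes the torch.xpu backend; the helper names resolve_accelerator and accelerator_count are illustrative and not part of this PR.

import torch


def resolve_accelerator(device_id: int) -> torch.device:
    """Pick and activate the local accelerator for a worker (CUDA first, then XPU)."""
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{device_id}")
        torch.cuda.set_device(device)
    elif torch.xpu.is_available():
        device = torch.device(f"xpu:{device_id}")
        torch.xpu.set_device(device)
    else:
        raise RuntimeError("No CUDA or XPU accelerator available")
    return device


def accelerator_count() -> int:
    """Number of visible accelerators, preferring CUDA and falling back to XPU."""
    if torch.cuda.device_count() > 0:
        return torch.cuda.device_count()
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    return 0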
63 changes: 42 additions & 21 deletions collector/collect_all_reduce.py
@@ -31,9 +31,14 @@
from typing import Optional

import torch

import torch.distributed as dist
from helper import PowerMonitor, log_perf

def sync_device():
if torch.cuda.is_available():
torch.cuda.synchronize()
elif torch.xpu.is_available():
torch.xpu.synchronize()

def get_input_shape_and_comm_size(size, token_dim=4096):
"""Convert size to appropriate input shape for AllReduce operations"""
@@ -271,7 +276,10 @@ def setup_vllm_distributed(world_size, rank, use_slurm):
local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))

# Set CUDA device
torch.cuda.set_device(local_rank)
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.set_device(local_rank)

# Initialize distributed environment
if not torch.distributed.is_initialized():
@@ -289,12 +297,17 @@ def setup_vllm_distributed(world_size, rank, use_slurm):
print(f" Local rank: {local_rank}")

try:
if torch.xpu.is_available():
dist.init_process_group(
backend="xccl",
init_method="env://",
)
vllm_mods["init_distributed_environment"](
world_size=world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
backend="nccl",
backend="xccl" if torch.xpu.is_available() else "nccl",
)
except Exception as e:
print(f"\nERROR: Failed to initialize distributed environment: {e}")
@@ -332,21 +345,25 @@ def benchmark_vllm_allreduce(
num_runs = 20

# Warmup communication
warmup_tensor = torch.ones(1, dtype=torch_dtype, device="cuda")
warmup_tensor = torch.ones(1, dtype=torch_dtype, device="cuda" if torch.cuda.is_available() else "xpu")
_ = vllm_mods["tensor_model_parallel_all_reduce"](warmup_tensor)
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()

size = min_size
while size < max_size:
input_shape = get_input_shape_and_comm_size(size)

# Test both graph capture and eager mode
for use_graph in [True, False]:
flag_lst = [False] if torch.xpu.is_available() else [True, False]
for use_graph in [flag_lst]:
mode_str = "graph" if use_graph else "eager"

if use_graph:
# Graph capture mode
with vllm_mods["graph_capture"](device=torch.cuda.current_device()) as graph_capture_context:
with vllm_mods["graph_capture"](device=torch.xpu.current_device()) as graph_capture_context:
# Create input tensors
input_tensors = []
for _ in range(repeat_n):
@@ -381,10 +398,10 @@
actual_num_runs = min(actual_num_runs, 1000)
else:
# Normal warmup
torch.cuda.synchronize()
sync_device()
for i in range(num_warmups):
graph.replay()
torch.cuda.synchronize()
sync_device()

# Initialize power monitoring
power_monitor = None
@@ -395,22 +412,22 @@
power_monitor = None

# Timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Graph-capture timing uses CUDA events; this path only runs when CUDA is available
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(actual_num_runs):
graph.replay()
end_event.record()
torch.cuda.synchronize()
sync_device()

# Stop power monitoring
if power_monitor:
power_stats = power_monitor.stop_sampling()

else:
# Eager mode
input_tensor = torch.ones(input_shape, dtype=torch_dtype, device="cuda")
input_tensor = torch.ones(input_shape, dtype=torch_dtype, device="xpu")

# Adaptive num_runs calculation for power measurement
actual_num_runs = num_runs
@@ -432,11 +449,11 @@
actual_num_runs = min(actual_num_runs, 1000)
else:
# Normal warmup
torch.cuda.synchronize()
sync_device()
for _ in range(num_warmups):
for _ in range(repeat_n):
_ = vllm_mods["tensor_model_parallel_all_reduce"](input_tensor.clone())
torch.cuda.synchronize()
sync_device()

# Initialize power monitoring
power_monitor = None
@@ -447,15 +464,19 @@
power_monitor = None

# Timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if torch.cuda.is_available():
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
elif torch.xpu.is_available():
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)

start_event.record()
for _ in range(actual_num_runs):
for _ in range(repeat_n):
_ = vllm_mods["tensor_model_parallel_all_reduce"](input_tensor.clone())
end_event.record()
torch.cuda.synchronize()
sync_device()

# Stop power monitoring
if power_monitor:
@@ -488,7 +509,7 @@ def benchmark_vllm_allreduce(
],
framework="vLLM",
version=vllm_version,
device_name=torch.cuda.get_device_name(),
device_name=torch.xpu.get_device_name() if torch.xpu.is_available() else torch.cuda.get_device_name(),
op_name="all_reduce",
kernel_source=f"vLLM_custom_{mode_str}",
perf_filename=perf_filename,
@@ -498,7 +519,7 @@ def benchmark_vllm_allreduce(
size *= ratio

# Synchronize all ranks before cleanup
torch.cuda.synchronize()
sync_device()
if torch.distributed.is_initialized():
torch.distributed.barrier()

@@ -576,7 +597,7 @@ def allreduce_benchmark(
parser.add_argument(
"--range",
"-r",
default="128,1073741824,2", # 128B to 1024MB
default="128,268435456,2" if torch.xpu.is_available() else "128,1073741824,2", # 128B to 1024MB
help="min_size,max_size,multiplicative_ratio",
)
parser.add_argument("--use-slurm", action="store_true", help="Use SLURM environment variables")
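Note: benchmark_vllm_allreduce above creates timing events and synchronizes through separate per-backend branches in several places. Below is a minimal sketch of a device-agnostic timing helper, assuming torch.xpu.Event supports enable_timing and elapsed_time like its CUDA counterpart; the helper names make_timing_events and elapsed_ms are illustrative and not part of this PR.

import torch


def make_timing_events():
    """Create a (start, end) pair of timing events on the active backend."""
    if torch.cuda.is_available():
        event_cls = torch.cuda.Event
    elif torch.xpu.is_available():
        event_cls = torch.xpu.Event
    else:
        raise RuntimeError("No CUDA or XPU accelerator available")
    return event_cls(enable_timing=True), event_cls(enable_timing=True)


def elapsed_ms(start_event, end_event) -> float:
    """Synchronize the active device and return elapsed milliseconds between events."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elif torch.xpu.is_available():
        torch.xpu.synchronize()
    return start_event.elapsed_time(end_event)

Usage would mirror the existing pattern: record the start event, replay the graph or launch the collective actual_num_runs times, record the end event, then call elapsed_ms(start, end).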
22 changes: 12 additions & 10 deletions collector/collect_comm.sh
@@ -65,17 +65,19 @@ num_gpus_nccl=(2 4 8)
nccl_ops=("all_gather" "alltoall" "reduce_scatter" "all_reduce")
dtypes=("half" "int8")

# NOTE: comment out the NCCL loop below when running on XPU
# TODO add oneCCL perf collector
for n in "${num_gpus_nccl[@]}"; do
for op in "${nccl_ops[@]}"; do
for dtype in "${dtypes[@]}"; do
if [[ "$measure_power" == "true" ]]; then
python3 collect_nccl.py -n "$n" -NCCL "$op" --dtype "$dtype" \
--measure_power --power_test_duration_sec "$power_test_duration"
else
python3 collect_nccl.py -n "$n" -NCCL "$op" --dtype "$dtype"
fi
done
done
for op in "${nccl_ops[@]}"; do
for dtype in "${dtypes[@]}"; do
if [[ "$measure_power" == "true" ]]; then
python3 collect_nccl.py -n "$n" -NCCL "$op" --dtype "$dtype" \
--measure_power --power_test_duration_sec "$power_test_duration"
else
python3 collect_nccl.py -n "$n" -NCCL "$op" --dtype "$dtype"
fi
done
done
done

echo "Running AllReduce Benchmarks with $all_reduce_backend backend..."
73 changes: 40 additions & 33 deletions collector/common_test_cases.py
@@ -3,7 +3,7 @@
import dataclasses
import itertools
from typing import Optional

import torch

@dataclasses.dataclass
class MoeCommonTestCase:
@@ -31,26 +31,26 @@ def get_common_moe_test_cases():
64,
80,
96,
128,
160,
192,
256,
320,
384,
512,
768,
1024,
1536,
2048,
3072,
4096,
6144,
8192,
12288,
16384,
20480,
32768,
65536,
# 128,  # FIXME: crashes here; check why memory usage exceeds 24 GB
# 160,
# 192,
# 256,
# 320,
# 384,
# 512,
# 768,
# 1024,
# 1536,
# 2048,
# 3072,
# 4096,
# 6144,
# 8192,
# 12288,
# 16384,
# 20480,
# 32768,
# 65536,
]
tp_list = [1, 2, 4, 8, 16, 32]
ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
@@ -70,17 +70,21 @@ def get_common_moe_test_cases():
# [2048,1408,4,60], #qwen1.5_moe
# [2048,1408,6,64], #deepseekv1_moe
# [5120,1536,6,160], #deepseekv2
model_config_list = [
[4096, 14336, 2, 8, "MOE_Mixtral8x7B"], # mixtral_8x7b
[6144, 16384, 2, 8, "MOE_Mixtral8x22B"], # mixtral_8x22b
[7168, 2048, 8, 256, "DEEPSEEK_V3"], # deepseekv3, will have 1 shared expert
[2048, 768, 8, 128, "QWEN3_30B_A3B"], # qwen3-moe, 30b-a3b
[4096, 1536, 8, 128, "QWEN3_235B"], # qwen3-moe, 235b-a22b
[6144, 2560, 8, 160, "QWEN3_480B"], # qwen3-moe, 480b-a35b
[7168, 2048, 8, 384, "KIMI_K2"], # kimi k2
[2880, 2880, 4, 128, "GPT_OSS_120B"],
[2880, 2880, 4, 32, "GPT_OSS_20B"],
]
if torch.cuda.is_available():
model_config_list = [
[4096, 14336, 2, 8, "MOE_Mixtral8x7B"], # mixtral_8x7b
[6144, 16384, 2, 8, "MOE_Mixtral8x22B"], # mixtral_8x22b
[7168, 2048, 8, 256, "DEEPSEEK_V3"], # deepseekv3, will have 1 shared expert
[2048, 768, 8, 128, "QWEN3_30B_A3B"], # qwen3-moe, 30b-a3b
[4096, 1536, 8, 128, "QWEN3_235B"], # qwen3-moe, 235b-a22b
[6144, 2560, 8, 160, "QWEN3_480B"], # qwen3-moe, 480b-a35b
[7168, 2048, 8, 384, "KIMI_K2"], # kimi k2
[2880, 2880, 4, 128, "GPT_OSS_120B"],
[2880, 2880, 4, 32, "GPT_OSS_20B"],
]
elif torch.xpu.is_available():
# FIXME sihan: start with this single config on XPU
model_config_list = [[2048, 1408, 4, 60, "QWEN1.5_MOE"]]  # qwen1.5_moe

test_cases: list[MoeCommonTestCase] = []

@@ -183,7 +187,10 @@ def get_gemm_common_test_cases() -> list[GemmCommonTestCase]:
10240,
12288,
]
nk_list_ext = [16384, 65536] # for coverage and interp purpose
if torch.cuda.is_available():
nk_list_ext = [16384, 65536]  # for coverage and interpolation purposes
elif torch.xpu.is_available():  # FIXME narrow down the search space for XPU
nk_list_ext = []

test_cases = []
# x_list_orig+add+ext <==> nk_list+ext
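Note: in get_common_moe_test_cases above, model_config_list is only assigned inside the CUDA and XPU branches, so it stays unbound when neither backend is present. A minimal sketch of the same selection with an explicit fallback is shown below; the function name select_moe_model_configs and the empty-list fallback are illustrative assumptions, not part of this PR.

import torch


def select_moe_model_configs() -> list[list]:
    """Return per-backend MoE model configs: [hidden, inter, topk, n_experts, name]."""
    cuda_configs = [
        [4096, 14336, 2, 8, "MOE_Mixtral8x7B"],
        [7168, 2048, 8, 256, "DEEPSEEK_V3"],
        # ... remaining CUDA configs as listed in the diff above
    ]
    xpu_configs = [
        [2048, 1408, 4, 60, "QWEN1.5_MOE"],
    ]
    if torch.cuda.is_available():
        return cuda_configs
    if torch.xpu.is_available():
        return xpu_configs
    return []  # no accelerator visible: skip the MoE sweep instead of raising NameError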