🐛 Bug
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered in _xformers_tiled_matmul_kernel when processing large matrices in _launch_triton_matmul
This error is easy to trigger when training with tensor parallelism plus sequence parallelism (TP + SP) on long sequences with a large vocabulary.
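One observation that may explain why only large shapes are affected (my hypothesis, not verified against the kernel source): with the shapes in the reproducer below, the matmul output holds more than 2^31 elements, so any int32 offset arithmetic inside `_xformers_tiled_matmul_kernel` would overflow.

```python
# Size check for the reproducer's shapes.
# Hypothesis (unverified): int32 offset overflow inside the Triton kernel.
seq_len_global = 65536 * 2            # 131072
hidden_out = 64512

n_out = seq_len_global * hidden_out   # 8,455,716,864 output elements
int32_max = 2**31 - 1                 # 2,147,483,647

print(n_out > int32_max)              # True
print(65536 * 64512 > int32_max)      # True -> overflows even per rank
```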
To Reproduce
Steps to reproduce the behavior:
```python
import os

import torch
import torch.distributed as dist

from xformers.ops.seqpar import sequence_parallel_leading_matmul


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", rank % max(1, torch.cuda.device_count())))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Shapes chosen so the output (seq_len_global x hidden_out)
    # exceeds 2**31 elements.
    seq_len_global = 65536 * 2
    hidden_in = 2048
    hidden_out = 64512
    dtype = torch.bfloat16

    seq_len_local = seq_len_global // world_size
    x = torch.randn(seq_len_local, hidden_in, device=device, dtype=dtype, requires_grad=True)
    w = torch.randn(hidden_in, hidden_out, device=device, dtype=dtype, requires_grad=True)

    (y,) = sequence_parallel_leading_matmul(x, [w], fuse=True, process_group=dist.group.WORLD)
    y.mean().backward()


if __name__ == "__main__":
    main()
```

Command:

```
CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 reproduce_error.py
```
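As a control experiment (my own sketch, not part of the original report), the same shapes should go through a plain eager-mode matmul; if so, that points at the fused Triton path rather than the shapes themselves:

```python
# Control sketch (assumes one GPU with roughly 9 GB free for these tensors).
# If this succeeds while the fused op fails, the fused kernel's indexing
# is the likely suspect.
import torch

x = torch.randn(65536, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2048, 64512, device="cuda", dtype=torch.bfloat16)
y = x @ w
print(y.shape, y.numel())  # torch.Size([65536, 64512]) 4227858432
```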
Environment
```
xFormers 0.0.30+56be3b5.d20250810
pytorch.version: 2.7.0a0+79aa17489c.nv25.04
build.python_version: 3.12.3
build.torch_version: 2.7.0a0+79aa17489c.nv25.04
build.nvcc_version: 12.9.41
pytorch-triton==3.2.0+git4b3bb1f8b.nvinternal
```
Solution
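No fix is proposed in this report. As an untested sketch of a possible mitigation (my assumption, not a confirmed fix): `fuse=False` selects the non-fused implementation of the same op, which may route around `_xformers_tiled_matmul_kernel`. This is a drop-in change inside `main()` from the reproducer above:

```python
# Untested workaround sketch: same call as the reproducer, but fuse=False,
# assuming the non-fused path avoids the failing Triton kernel.
(y,) = sequence_parallel_leading_matmul(x, [w], fuse=False, process_group=dist.group.WORLD)
y.mean().backward()
```

If this runs cleanly on the same shapes, it would also narrow the bug down to the fused tiled-matmul path.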