Bug description
Summary
When training with torchtitan, if the torch._grouped_mm input tensor is large enough, a torch.AcceleratorError: CUDA error: an illegal memory access was encountered error occurs.
Specifically, when training an LLM with torchtitan, I set:
- local batch size = 16
- sequence length = 4096
- moe num activated experts = 8
- hidden dim = 4096
 
Then, after some padding applied by torchtitan, the shape of the input tensor to torch._grouped_mm is [525312, 4096] and the error occurs. Training works fine when I decrease any of these four factors (local batch size, sequence length, moe num activated experts, hidden dim).
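For scale: 16 × 4096 tokens per rank routed to top-8 experts gives 16 × 4096 × 8 = 524,288 token-expert rows, padded up to 525,312. A [525312, 4096] tensor then holds 525,312 × 4096 = 2,151,677,952 elements, just past the int32 maximum of 2,147,483,647, and the failing Triton launch in the stacktrace below is handed an even larger index bound (3,225,419,776), so a 32-bit index overflow in a generated kernel looks plausible to me.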
From what I remember, this error wasn't present before an update around late September.
This error happens even when n_layers=1.
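Below is a minimal sketch of the failing shape, calling torch._grouped_mm directly with the 2D-activation / 3D-weight / int32-offsets variant that torchtitan's MoE uses. This is my own construction with made-up even group boundaries, and since the crash surfaces inside an Inductor-generated index_put kernel, this standalone call may well run cleanly outside torch.compile:

import torch

M, K, N = 525312, 4096, 4096  # padded token-expert rows, hidden dim in/out
G = 128                       # num_experts

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)
# Even group boundaries for illustration; real routing yields uneven offsets.
offs = torch.arange(1, G + 1, device="cuda", dtype=torch.int32) * (M // G)

out = torch._grouped_mm(a, w, offs=offs)  # expected shape [M, N]
torch.cuda.synchronize()  # surface any asynchronous CUDA error here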
Stacktrace
Root Cause (first observed failure):
[0]:
  time      : 2025-10-24_18:38:53
  host      : Slurm-GPU-Node-44
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 3608571)
  error_file: /tmp/torchelastic_gqd1bzn0/none_uf9e6gfx/attempt_0/5/error.json
  traceback : Traceback (most recent call last):
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 612, in train
      self.train_step(data_iterator)
    File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 512, in train_step
      loss = self.forward_backward_step(input_dict, labels)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 488, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1886, in _call_impl
      return inner()
             ^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1834, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/models/llama4/model/model.py", line 560, in forward
      h = layer(h, self.freqs_cis, attention_masks)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 433, in __call__
      return super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1886, in _call_impl
      return inner()
             ^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1834, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 912, in compile_wrapper
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 145, in forward
      def forward(self, *args, **kwargs):
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1129, in _fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1139, in forward
      return compiled_fn(full_args)
             ^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 343, in runtime_wrapper
      all_outs = call_func_at_runtime_with_args(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 133, in call_func_at_runtime_with_args
      out = normalize_as_list(f(args))
                              ^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 107, in g
      return f(*args)
             ^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/autograd/function.py", line 583, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2152, in forward
      fw_outs = call_func_at_runtime_with_args(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 133, in call_func_at_runtime_with_args
      out = normalize_as_list(f(args))
                              ^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 531, in wrapper
      return compiled_fn(runtime_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 695, in inner_fn
      unwrapped_outs = compiled_fn(unwrapped_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 729, in inner_fn
      outs = compiled_fn(args)
             ^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 618, in __call__
      return self.current_callable(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/utils.py", line 3059, in run
      out = model(new_inputs)
            ^^^^^^^^^^^^^^^^^
    File "/tmp/torchinductor_yoonsoo/dl/cdlwx2fwmcdscf2xtltwynghxgi2parcmk5iam23dqpnnrd6ifv6.py", line 2145, in call
      triton_poi_fused_index_put_18.run(buf61, buf68, buf69, 3225419776, stream=stream5)
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1335, in run
      self.autotune_to_one_config(*args, **kwargs)
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1097, in autotune_to_one_config
      timings = self.benchmark_all_configs(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1060, in benchmark_all_configs
      launcher: self.bench(launcher, *args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 922, in bench
      return benchmarker.benchmark_gpu(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 89, in wrapper
      return fn(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 320, in benchmark_gpu
      torch.cuda.synchronize()
    File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/cuda/__init__.py", line 1094, in synchronize
      return torch._C._cuda_synchronize()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  torch.AcceleratorError: CUDA error: an illegal memory access was encountered
  Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
  CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
  For debugging consider passing CUDA_LAUNCH_BLOCKING=1
  Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
How to reproduce
1. Install
conda create -y -n torchtitan-tmp python=3.12
conda activate torchtitan-tmp
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 --force-reinstall
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-4-Scout-17B-16E --assets tokenizer --hf_token=...
2. Create torchtitan config
Place the following train config at torchtitan/models/llama4/train_configs/tmp.toml:
[job]
dump_folder = "./outputs"
description = "Llama 4 Scout 17Bx16E training"
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 10
enable_tensorboard = false
save_tb_folder = "tb"
[model]
name = "llama4"
flavor = "tmp"
hf_assets_path = "./assets/hf/Llama-4-Scout-17B-16E"
# converters = ["quantize.linear.float8"]
[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15
[lr_scheduler]
warmup_steps = 600
min_lr_factor = 0.1
[training]
local_batch_size = 16
seq_len = 4096
max_norm = 1.0  # grad norm clipping
steps = 3000
dataset = "c4"
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1
expert_tensor_parallel_degree = 1
[checkpoint]
enable = false
folder = "checkpoint"
interval = 500
last_save_model_only = true
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]
[compile]
enable = true
components = ["model", "loss"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "router.gate"]
[quantize.linear.mx]
filter_fqns = ["output", "router.gate"]
3. Add args
Add the following model args to torchtitan/models/llama4/__init__.py:
    "tmp": TransformerModelArgs(
        dim=4096,
        n_layers=1,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        rope_scaling_args=RoPEScalingArgs(),
        max_seq_len=10485760,
        moe_args=MoEArgs(num_experts=128, top_k=8),
        interleave_moe_layer_step=1,
    ),
4. Run
torchrun --nproc_per_node=8 --rdzv_backend=c10d -m torchtitan.train --job.config_file torchtitan/models/llama4/train_configs/tmp.toml
Env
Collecting environment information...
PyTorch version: 2.10.0.dev20251023+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1029-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA H200
GPU 1: NVIDIA H200
GPU 2: NVIDIA H200
GPU 3: NVIDIA H200
GPU 4: NVIDIA H200
GPU 5: NVIDIA H200
GPU 6: NVIDIA H200
GPU 7: NVIDIA H200
Nvidia driver version: 570.158.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.12.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               96
On-line CPU(s) list:                  0-95
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 7R13 Processor
CPU family:                           25
Model:                                1
Thread(s) per core:                   1
Core(s) per socket:                   48
Socket(s):                            2
Stepping:                             1
BogoMIPS:                             5299.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            3 MiB (96 instances)
L1i cache:                            3 MiB (96 instances)
L2 cache:                             48 MiB (96 instances)
L3 cache:                             384 MiB (12 instances)
NUMA node(s):                         4
NUMA node0 CPU(s):                    0-23
NUMA node1 CPU(s):                    24-47
NUMA node2 CPU(s):                    48-71
NUMA node3 CPU(s):                    72-95
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; Safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pytorch-triton==3.5.0+git7416ffcb
[pip3] torch==2.10.0.dev20251023+cu128
[pip3] torchao==0.15.0.dev20251024+cu128
[pip3] torchdata==0.11.0
[pip3] triton==3.5.0
[conda] numpy                       2.3.4                     pypi_0              pypi
[conda] nvidia-cublas-cu12          12.8.4.1                  pypi_0              pypi
[conda] nvidia-cuda-cupti-cu12      12.8.90                   pypi_0              pypi
[conda] nvidia-cuda-nvrtc-cu12      12.8.93                   pypi_0              pypi
[conda] nvidia-cuda-runtime-cu12    12.8.90                   pypi_0              pypi
[conda] nvidia-cudnn-cu12           9.10.2.21                 pypi_0              pypi
[conda] nvidia-cufft-cu12           11.3.3.83                 pypi_0              pypi
[conda] nvidia-curand-cu12          10.3.9.90                 pypi_0              pypi
[conda] nvidia-cusolver-cu12        11.7.3.90                 pypi_0              pypi
[conda] nvidia-cusparse-cu12        12.5.8.93                 pypi_0              pypi
[conda] nvidia-cusparselt-cu12      0.7.1                     pypi_0              pypi
[conda] nvidia-nccl-cu12            2.27.5                    pypi_0              pypi
[conda] nvidia-nvjitlink-cu12       12.8.93                   pypi_0              pypi
[conda] nvidia-nvtx-cu12            12.8.90                   pypi_0              pypi
[conda] pytorch-triton              3.5.0+git7416ffcb         pypi_0              pypi
[conda] torch                       2.10.0.dev20251023+cu128  pypi_0              pypi
[conda] torchao                     0.15.0.dev20251024+cu128  pypi_0              pypi
[conda] torchdata                   0.11.0                    pypi_0              pypi
[conda] triton                      3.5.0                     pypi_0              pypi