
[amd] fix out of resources issue related to turning on async_copy#808

Merged
xuzhao9 merged 6 commits into meta-pytorch:main from scxiao:scxiao/fix_out_of_resources_issue
Feb 13, 2026
Conversation

@scxiao (Contributor) commented Jan 21, 2026

Turning on async copy needs more LDS, so some configurations may run out of LDS. Specifically, LDS usage increases from (num_stages - 1) * usage_per_iter to num_stages * usage_per_iter. This PR fixes that by decreasing num_stages by 1 to avoid running out of resources.
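
A minimal sketch of the idea (the names and the 64 KiB LDS budget below are illustrative assumptions, not the actual code changed in this PR):

LDS_LIMIT_BYTES = 64 * 1024  # assumed LDS budget per workgroup on AMD GPUs

def lds_usage(num_stages: int, usage_per_iter: int, async_copy: bool) -> int:
    # Without async copy the pipeline keeps num_stages - 1 stages resident in
    # LDS; with async copy all num_stages stages are resident.
    resident_stages = num_stages if async_copy else num_stages - 1
    return resident_stages * usage_per_iter

def adjust_num_stages(num_stages: int, usage_per_iter: int, async_copy: bool) -> int:
    # If enabling async copy pushes a config over the LDS budget, drop
    # num_stages by 1 so the resident footprint matches the old pipeline.
    if async_copy and lds_usage(num_stages, usage_per_iter, True) > LDS_LIMIT_BYTES:
        return max(num_stages - 1, 1)
    return num_stages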

meta-cla bot added the cla signed label Jan 21, 2026
xuzhao9 changed the title from "fix out of resources issue related to turning on async_copy" to "[amd] fix out of resources issue related to turning on async_copy" Jan 21, 2026
@njriasan (Contributor) left a comment

One minor question regarding whether we should check the actual Triton knob, but otherwise this looks good to me.

Comment thread on tritonbench/operators/fp8_gemm/persistent.py (Outdated)
@xuzhao9 (Contributor) commented Jan 21, 2026

Can you help check the CI failure? It seems addmm, gemm, and swiglu are still failing.

@scxiao (Contributor, Author) commented Jan 21, 2026

The problem in these three tests seems not to be related to running out of shared memory, but it is related to async copy. They all show the same problem, and it is a little weird: every time I re-compile Triton, these errors can be reproduced, and I can see gpucore.**** dumps. It looks like a memory access violation problem.

On the other hand, if I do NOT set async_copy, the tests pass. Once they have passed, re-running with async_copy enabled still passes until I recompile Triton. Still looking into the issue; I will update you later.

@scxiao (Contributor, Author) commented Jan 22, 2026

Here is the error message

after, kernel_call, z_shape = torch.Size([20120, 512])
before, kernel_call
after, kernel_call, z_shape = torch.Size([20120, 512])
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py:3584: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  current_out_size = out_base.storage().size()
Memory access fault by GPU node-1 (Agent handle: 0x55e79b891d70) on address 0x7f1d4bf89000. Reason: Unknown.
GPU core dump created: gpucore.1606

The GPU memory access violation happens after the kernel execution; it seems related to torch compile. Do you have any idea what is wrong here?

@scxiao (Contributor, Author) commented Jan 22, 2026

For the gemm test, here is the error output:

>>>>--------------
WARNING:tritonbench.utils.triton_op:First-k mode: Selected 1 sequential inputs starting from index 0 (total available: 12)
WARNING:tritonbench.utils.triton_op:Input IDs to run: [0]
loc1
  0%|                                                                                                                                               | 0/1 [00:00<?, ?it/s]WARNING:tritonbench.utils.triton_op:Running input ID 0:
(M, N, K)
-----------------
(8192, 8192, 512)
INFO:tritonbench.utils.triton_op:Took 0.01ms to get benchmark function for aten_matmul
INFO:tritonbench.utils.triton_op:Took 0.00ms to get benchmark function for triton_tutorial_matmul
INFO:tritonbench.utils.triton_op:Took 0.33ms to get benchmark function for matmul_partition_k
INFO:tritonbench.utils.triton_op:Took 841.92ms to get benchmark function for aten_tunableop_matmul
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py:3584: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  current_out_size = out_base.storage().size()
GPU core dump created: gpucore.258
Kernel Name: triton_mm
VGPU=0x55d93e589ce0 SWq=0x7f3404bf8000, HWq=0x7f2be1e00000, id=1
        Dispatch Header =0xb02 (type=2, barrier=1, acquire=1, release=1), setup=0
        grid=[8388608, 1, 1], workgroup=[256, 1, 1]
        private_seg_size=0, group_seg_size=50624
        kernel_obj=0x7f2be14b85c0, kernarg_address=0x0x7f2be1c00000
        completion_signal=0x0, correlation_id=0
        rptr=36506, wptr=36508
:0:rocdevice.cpp            :3587: 1368043037972 us:  Callback: Queue 0x7f2be1e00000 aborting with error : HSA_STATUS_ERROR_MEMORY_APERTURE_VIOLATION: The agent attempted to access memory beyond the largest legal address. code: 0x29
Aborted (core dumped)

We can see the kernel name is triton_mm. I searched tritonbench but did not find this kernel; do you know where it comes from? The swiglu test also reports the same error.

@xuzhao9 (Contributor) commented Jan 22, 2026

Hi @scxiao, I believe triton_mm is generated by the PyTorch compiler. You could run the command with TORCH_TRACE=<trace_dir> python run.py ... to obtain the torch traces, then run tlparse <trace_dir> to get the pt2-generated Triton code.
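
For example (the trace directory below is just a placeholder, and the run.py arguments are elided as above):

TORCH_TRACE=/tmp/tb_trace python run.py ...
tlparse /tmp/tb_trace

tlparse then produces a report that links to the Triton kernels inductor generated for the run, so you can inspect the triton_mm source.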

@scxiao (Contributor, Author) commented Jan 27, 2026

Some update: the issue is that the input tensors do not have the attribute tt.pointer_range = 32 : i32 in the torch inductor case. We can manually remove this attribute in the Triton compiler to reproduce the problem, and we are working on fixing that now.
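
To illustrate (the syntax below is approximate and hand-written, not dumped from the failing kernel): when the attribute is present, a kernel pointer argument in the Triton IR carries it roughly like

%arg0: !tt.ptr<f16> {tt.pointer_range = 32 : i32}

while in the inductor-generated triton_mm kernel this annotation is missing from the input pointers, which is what triggers the memory access fault here.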

@xuzhao9 (Contributor) commented Feb 6, 2026

Thanks for looking into it, @scxiao !

cc @karthickai @njriasan: the Triton AMD async copy update does not play well with torchinductor; should we consider adding tt.pointer_range = 32 : i32 in inductor codegen?

@scxiao (Contributor, Author) commented Feb 6, 2026

We are working on that now and are almost done. The problem is on the LLVM backend side, and the fix has been landed. BTW, I would expect that adding tt.pointer_range = 32 : i32 in inductor codegen would help performance in many places, since this attribute guarantees the use of buffer ops and enables a few optimizations. But I am not sure whether we can add that in codegen.

@scxiao (Contributor, Author) commented Feb 12, 2026

Hi @xuzhao9, once this PR is merged: triton-lang/triton#9431, all the failing tests (memory access violation) will pass. Let us wait until it has landed, then we can resume this PR. Thanks.

@AlexAUT commented Feb 12, 2026

FYI the PR has landed

scxiao force-pushed the scxiao/fix_out_of_resources_issue branch from 5ade1e0 to fe0401a on February 12, 2026 19:48
@xuzhao9 (Contributor) commented Feb 12, 2026

It looks like there are still 3 failures left (addmm, gemm, swiglu); can you help take another look, @scxiao?

Nvm, just realized that we need to update the testing Docker image. I am working on that.

@scxiao (Contributor, Author) commented Feb 12, 2026

Hi @xuzhao9, does the CI build use the latest upstream Triton? If not, how can we set it to do so? Thanks.

@xuzhao9 (Contributor) commented Feb 13, 2026

@scxiao it updates upstream Triton every week, but we can manually kick off a run to update it.

@xuzhao9 (Contributor) commented Feb 13, 2026

The Docker image has been updated and this PR has fixed all test failures. Thanks a lot, @scxiao, for your contribution!

xuzhao9 merged commit 139aa25 into meta-pytorch:main on Feb 13, 2026
9 of 10 checks passed
