add contiguous inside rmsnorm kernel #95
Conversation
Signed-off-by: Kunshang Ji <[email protected]>
Pull request overview
This PR adds input tensor contiguity checks and enforcement directly within the RMSnorm kernel implementation, aligning with a similar change made in vLLM. It moves contiguity handling from the Python level to the C++ level to prevent potential accuracy issues.
Key Changes:
- Adds contiguity validation for output and weight tensors
- Implements automatic conversion to contiguous layout for input tensor when stride is not 1
- Adds runtime checks to ensure proper tensor layouts before kernel execution
```cpp
if (input.stride(-1) != 1) {
  input = input.contiguous();
}
TORCH_CHECK(input.stride(-1) == 1);
```
Copilot AI (Dec 12, 2025):
The check at line 211 is redundant. After the contiguous() call on line 209, the stride(-1) is guaranteed to be 1. This assertion will never fail and can be removed.
```cpp
TORCH_CHECK(input.stride(-1) == 1);
```
```cpp
if (input.stride(-1) != 1) {
  input = input.contiguous();
}
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "call_rms_norm_kernel", [&] {
      vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
```
Copilot AI (Dec 12, 2025):
Silently modifying the input parameter may lead to unexpected behavior for callers. Consider making input non-const to signal that it may be modified, documenting this behavior clearly, or working on a copy so the caller's reference is never mutated.
Suggested change:

```diff
-  if (input.stride(-1) != 1) {
-    input = input.contiguous();
-  }
-  TORCH_CHECK(input.stride(-1) == 1);
-  TORCH_CHECK(weight.is_contiguous());
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "call_rms_norm_kernel", [&] {
-        vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
+  auto input_ = input;
+  if (input_.stride(-1) != 1) {
+    input_ = input_.contiguous();
+  }
+  TORCH_CHECK(input_.stride(-1) == 1);
+  TORCH_CHECK(weight.is_contiguous());
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input_.scalar_type(), "call_rms_norm_kernel", [&] {
+        vllm::call_rms_norm_kernel<scalar_t>(out, input_, weight, epsilon);
```
CUDA may hit the same problem.
baodii
left a comment
LGTM
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.
Purpose
Add contiguous handling inside the rmsnorm kernel. vLLM made the same in-kernel change in vllm-project/vllm#28103 and removed the contiguous call at the Python level; without the in-kernel handling, that removal may bring a potential accuracy issue.
Test Plan
Test Result
(Optional) Documentation Update