Merged
6 changes: 6 additions & 0 deletions csrc/layernorm.cpp
@@ -204,6 +204,12 @@ void rms_norm(
    torch::Tensor& input,
    torch::Tensor& weight,
    double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  if (input.stride(-1) != 1) {
    input = input.contiguous();
  }
  TORCH_CHECK(input.stride(-1) == 1);

Copilot AI Dec 12, 2025

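The check on line 211 appears redundant. If the branch on line 208 is skipped, `stride(-1)` is already 1; if it is taken, the `contiguous()` call on line 209 yields a last-dimension stride of 1 for any tensor whose last dimension has more than one element. In practice this assertion should not fail and can likely be removed.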

Suggested change
-  TORCH_CHECK(input.stride(-1) == 1);

Copilot uses AI. Check for mistakes.
  TORCH_CHECK(weight.is_contiguous());
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "call_rms_norm_kernel", [&] {
        vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
Comment on lines +208 to 215

Copilot AI Dec 12, 2025

Silently reassigning the `input` parameter may lead to unexpected behavior for callers. Consider making `input` non-const to signal that modification is possible, documenting this behavior clearly, or working on a copy to avoid mutating the caller's reference.

Suggested change
-  if (input.stride(-1) != 1) {
-    input = input.contiguous();
-  }
-  TORCH_CHECK(input.stride(-1) == 1);
-  TORCH_CHECK(weight.is_contiguous());
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "call_rms_norm_kernel", [&] {
-        vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
+  auto input_ = input;
+  if (input_.stride(-1) != 1) {
+    input_ = input_.contiguous();
+  }
+  TORCH_CHECK(input_.stride(-1) == 1);
+  TORCH_CHECK(weight.is_contiguous());
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input_.scalar_type(), "call_rms_norm_kernel", [&] {
+        vllm::call_rms_norm_kernel<scalar_t>(out, input_, weight, epsilon);
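
To illustrate the concern about rebinding the caller's reference, here is a minimal, self-contained sketch. `ToyTensor`, `normalize_rebinding`, and `normalize_local` are hypothetical names invented for this example; `ToyTensor` is a simplified stand-in (a shared buffer plus a last-dimension stride) and is not the `torch::Tensor` API. It only models the handle semantics that matter here: assigning through a non-const reference changes which storage the caller's handle points to, while a local handle leaves the caller untouched.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor handle (NOT torch::Tensor):
// shared storage plus the stride of the last dimension.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> storage;
  std::int64_t last_stride;

  // Returns a handle whose last_stride is 1. In this toy model a
  // non-unit stride triggers a fresh buffer (a plain copy of the old one).
  ToyTensor contiguous() const {
    if (last_stride == 1) return *this;
    return ToyTensor{std::make_shared<std::vector<float>>(*storage), 1};
  }
};

// Same shape as the code in the diff: the assignment goes through the
// reference, so the caller's handle now points at the new storage.
void normalize_rebinding(ToyTensor& input) {
  if (input.last_stride != 1) {
    input = input.contiguous();
  }
}

// Same shape as the suggested change: only the local handle is rebound.
// Returns the stride that a kernel reading from input_ would see.
std::int64_t normalize_local(ToyTensor& input) {
  ToyTensor input_ = input;  // copies the handle, not the buffer
  if (input_.last_stride != 1) {
    input_ = input_.contiguous();
  }
  return input_.last_stride;
}
```

Calling `normalize_local` on a handle with `last_stride == 2` leaves the caller's handle and storage as they were, whereas `normalize_rebinding` replaces both. Whether rebinding is acceptable in `rms_norm` depends on how its callers use `input` afterwards, which this sketch does not model.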
