Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include <NvInferRuntime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/runtime/ipcUtils.h"

TRTLLM_NAMESPACE_BEGIN

namespace kernels::minimax_ar
{
// Trait mapping an activation element type to its vectorized-access width and
// the packed bf16 type used to load the RMSNorm gamma weights alongside it.
// Only half, nv_bfloat16, and float are supported; other types fail to link
// against the primary (undefined) template.
template <typename DType>
struct ElemsPerAccess;

// half: 8 elements per 16-byte vector access; gamma loaded as 8 packed bf16.
template <>
struct ElemsPerAccess<half>
{
static constexpr int value = 8;
using norm_weight_type = common::__nv_bfloat168;
};

// bf16: same 8-wide access as half (both are 2 bytes per element).
template <>
struct ElemsPerAccess<nv_bfloat16>
{
static constexpr int value = 8;
using norm_weight_type = common::__nv_bfloat168;
};

// float: 4 elements per 16-byte access; gamma loaded as 4 packed bf16.
template <>
struct ElemsPerAccess<float>
{
static constexpr int value = 4;
using norm_weight_type = common::__nv_bfloat164;
};

// Convenience variable template: kElemsPerAccess<DType> == ElemsPerAccess<DType>::value.
template <typename DType>
static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;

// Parameter bundle for the fused all-reduce + RMSNorm kernel.
// Supports a single-matrix ("Q-only") path and a fused Q/K path: the K fields
// are used only when allreduce_in_k is non-null. All pointers are device
// pointers owned by the caller; this struct does not take ownership.
struct MiniMaxReduceRMSParams
{
    int nranks{};             // number of ranks participating in the all-reduce
    int rank{};               // this process's rank, in [0, nranks)
    nvinfer1::DataType dtype{}; // element type of the activations; value-initialized
                                // (was previously left indeterminate on default construction)
    int size_q{};             // numel of Q (num_token * head_dim_q)
    int hidden_dim{};         // head_dim_q
    int size_k{};             // numel of K (num_token * head_dim_k)
    int hidden_dim_k{};       // head_dim_k; must have head_dim_q >= head_dim_k
    void** workspace{};       // IPC workspace pointer table for cross-rank communication
    void* allreduce_in{};     // Q input
    void* rms_norm_out{};     // Q output
    void* rms_gamma{};        // Q norm weight
    void* allreduce_in_k{};   // K input (nullptr for single-matrix path)
    void* rms_norm_out_k{};   // K output
    void* rms_gamma_k{};      // K norm weight
    float rms_eps{};          // epsilon added inside the RMSNorm rsqrt
    cudaStream_t stream{};    // CUDA stream the kernel is launched on
    bool trigger_completion_at_end = true; // signal cross-rank completion after the norm
};

void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params);

} // namespace kernels::minimax_ar

TRTLLM_NAMESPACE_END
115 changes: 114 additions & 1 deletion cpp/tensorrt_llm/thop/allreduceOp.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -23,6 +23,7 @@
#include "tensorrt_llm/common/ncclUtils.h"
#include "tensorrt_llm/common/nvmlWrapper.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h"
Expand Down Expand Up @@ -1822,6 +1823,96 @@ std::vector<torch::Tensor> mnnvlFusionAllReduce(torch::Tensor& input, torch::opt
return {output, residualOut};
}

/**
 * @brief Fused all-reduce + RMSNorm over a single 2D activation tensor.
 *
 * Validates shapes/dtypes, fills a MiniMaxReduceRMSParams, and dispatches to
 * the minimax_ar kernel. The all-reduce result is normalized with
 * `norm_weight` (bfloat16 gamma) and written to a freshly allocated output.
 *
 * @param input        2D contiguous activation tensor [num_token, hidden_dim].
 * @param norm_weight  1D bfloat16 gamma of length hidden_dim.
 * @param workspace    IPC workspace tensor holding the cross-rank pointer table.
 * @param rank         This process's rank.
 * @param nranks       Total number of participating ranks.
 * @param eps          RMSNorm epsilon.
 * @param trigger_completion_at_end_ Forwarded to the kernel's completion flag.
 * @return Normalized all-reduce output, same shape/dtype as `input`.
 */
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor const& norm_weight,
    torch::Tensor const& workspace, int64_t const rank, int64_t const nranks, double const eps,
    bool const trigger_completion_at_end_)
{
    TORCH_CHECK(input.dim() == 2, "minimax_allreduce_rms: input must be 2D");
    TORCH_CHECK(norm_weight.dim() == 1, "minimax_allreduce_rms: norm_weight must be 1D");
    TORCH_CHECK(
        input.size(-1) == norm_weight.size(0), "minimax_allreduce_rms: input hidden dim must match norm_weight");
    TORCH_CHECK(input.is_contiguous(), "minimax_allreduce_rms: input must be contiguous");
    TORCH_CHECK(norm_weight.is_contiguous(), "minimax_allreduce_rms: norm_weight must be contiguous");
    TORCH_CHECK(norm_weight.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms: norm_weight must be bfloat16");

    // Allocate the output first so every kernel parameter is populated in one place.
    torch::Tensor rms_norm_out = torch::empty_like(input);

    auto allreduce_params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams();
    allreduce_params.nranks = static_cast<int>(nranks);
    allreduce_params.rank = static_cast<int>(rank);
    allreduce_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
    allreduce_params.size_q = static_cast<int>(input.numel());
    allreduce_params.hidden_dim = static_cast<int>(input.size(-1));
    allreduce_params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
    allreduce_params.allreduce_in = input.data_ptr();
    allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
    allreduce_params.rms_gamma = norm_weight.data_ptr();
    allreduce_params.rms_eps = static_cast<float>(eps);
    allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
    allreduce_params.trigger_completion_at_end = trigger_completion_at_end_;

    tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(allreduce_params);

    return rms_norm_out;
}

/**
 * @brief Fused all-reduce + RMSNorm over a Q and a K matrix in one launch.
 *
 * Q and K must share num_token and dtype; each is normalized with its own
 * bfloat16 gamma. Only the global (pre-TP-split) head dims 6144 (Q) and
 * 1024 (K) are currently supported, i.e. head_dim * nranks must match those.
 *
 * @param q              2D contiguous Q tensor [num_token, head_dim_q].
 * @param k              2D contiguous K tensor [num_token, head_dim_k].
 * @param norm_weight_q  1D bfloat16 gamma of length head_dim_q.
 * @param norm_weight_k  1D bfloat16 gamma of length head_dim_k.
 * @param workspace      IPC workspace tensor holding the cross-rank pointer table.
 * @param rank           This process's rank.
 * @param nranks         Total number of participating ranks.
 * @param eps            RMSNorm epsilon (shared by both norms).
 * @param trigger_completion_at_end_ Forwarded to the kernel's completion flag.
 * @return {normalized Q, normalized K}, same shapes/dtypes as the inputs.
 */
std::vector<torch::Tensor> minimax_allreduce_rms_qk(torch::Tensor const& q, torch::Tensor const& k,
    torch::Tensor const& norm_weight_q, torch::Tensor const& norm_weight_k, torch::Tensor const& workspace,
    int64_t const rank, int64_t const nranks, double const eps, bool const trigger_completion_at_end_)
{
    int64_t constexpr kSupportedGlobalHeadDimQ = 6144;
    int64_t constexpr kSupportedGlobalHeadDimK = 1024;

    TORCH_CHECK(q.scalar_type() == k.scalar_type(), "minimax_allreduce_rms_qk: q and k must have same dtype");
    TORCH_CHECK(q.dim() == 2 && k.dim() == 2, "minimax_allreduce_rms_qk: q and k must be 2D");
    TORCH_CHECK(q.size(0) == k.size(0), "minimax_allreduce_rms_qk: q and k must have same num_token");
    TORCH_CHECK(q.is_contiguous(), "minimax_allreduce_rms_qk: q must be contiguous");
    TORCH_CHECK(k.is_contiguous(), "minimax_allreduce_rms_qk: k must be contiguous");
    TORCH_CHECK(norm_weight_q.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_q must be 1D");
    TORCH_CHECK(norm_weight_k.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_k must be 1D");
    TORCH_CHECK(norm_weight_q.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_q must be contiguous");
    TORCH_CHECK(norm_weight_k.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_k must be contiguous");
    TORCH_CHECK(
        norm_weight_q.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_q must be bfloat16");
    TORCH_CHECK(
        norm_weight_k.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_k must be bfloat16");
    int64_t head_dim_q = q.size(-1);
    int64_t head_dim_k = k.size(-1);
    TORCH_CHECK(head_dim_q >= head_dim_k, "minimax_allreduce_rms_qk: head_dim_q must be >= head_dim_k");
    TORCH_CHECK(head_dim_q == norm_weight_q.size(0), "minimax_allreduce_rms_qk: q hidden dim must match norm_weight_q");
    TORCH_CHECK(head_dim_k == norm_weight_k.size(0), "minimax_allreduce_rms_qk: k hidden dim must match norm_weight_k");
    TORCH_CHECK((head_dim_q * nranks) == kSupportedGlobalHeadDimQ && (head_dim_k * nranks) == kSupportedGlobalHeadDimK,
        "minimax_allreduce_rms_qk: only global q/k dims 6144/1024 are currently supported");

    // Allocate both outputs first so every kernel parameter is populated in one place.
    torch::Tensor rms_norm_out_q = torch::empty_like(q);
    torch::Tensor rms_norm_out_k = torch::empty_like(k);

    auto params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams();
    params.nranks = static_cast<int>(nranks);
    params.rank = static_cast<int>(rank);
    params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(q.scalar_type());
    params.size_q = static_cast<int>(q.numel());
    params.hidden_dim = static_cast<int>(head_dim_q);
    params.size_k = static_cast<int>(k.numel());
    params.hidden_dim_k = static_cast<int>(head_dim_k);
    params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
    params.allreduce_in = q.data_ptr();
    params.rms_norm_out = rms_norm_out_q.mutable_data_ptr();
    params.rms_gamma = norm_weight_q.data_ptr();
    params.allreduce_in_k = k.data_ptr();
    params.rms_norm_out_k = rms_norm_out_k.mutable_data_ptr();
    params.rms_gamma_k = norm_weight_k.data_ptr();
    params.rms_eps = static_cast<float>(eps);
    params.stream = at::cuda::getCurrentCUDAStream(q.get_device());
    params.trigger_completion_at_end = trigger_completion_at_end_;

    tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(params);

    return {rms_norm_out_q, rms_norm_out_k};
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END
Expand Down Expand Up @@ -1886,6 +1977,26 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"int nranks,"
"float eps) -> Tensor[]");
m.def("preallocate_nccl_window_buffer(Tensor input, int[] group, int count) -> ()");
m.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps,"
"bool trigger_completion_at_end) -> Tensor");
m.def(
"minimax_allreduce_rms_qk("
"Tensor q,"
"Tensor k,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps,"
"bool trigger_completion_at_end) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
Expand All @@ -1896,6 +2007,8 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("moe_allreduce", &tensorrt_llm::torch_ext::moe_allreduce);
m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce);
m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer);
m.impl("minimax_allreduce_rms", &tensorrt_llm::torch_ext::minimax_allreduce_rms);
m.impl("minimax_allreduce_rms_qk", &tensorrt_llm::torch_ext::minimax_allreduce_rms_qk);
}

TORCH_LIBRARY_IMPL(trtllm, CPU, m)
Expand Down
14 changes: 7 additions & 7 deletions security_scanning/examples/auto_deploy/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions security_scanning/examples/draft_target_model/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions security_scanning/examples/eagle/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions security_scanning/examples/llm-eval/lm-eval-harness/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions security_scanning/examples/lookahead/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions security_scanning/examples/medusa/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading