Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,15 @@ class BlockManager
return mWindowBlockManagers.begin()->first;
}

/// @brief Largest attention window size currently tracked by this BlockManager.
/// @return The last key of mWindowBlockManagers (the maximum window size,
///         assuming the container is ordered by window size — matches
///         getFirstWindowSize, which returns the first key), or 0 when no
///         window managers exist.
[[nodiscard]] SizeType32 getLastWindowSize() const
{
    // Guard the empty case up front; otherwise the reverse iterator yields the largest key.
    return mWindowBlockManagers.empty() ? 0 : mWindowBlockManagers.rbegin()->first;
}

[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumAllocNewBlocks(); });
Expand Down
7 changes: 5 additions & 2 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2149,10 +2149,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
// When num_layers < len(maxAttentionWindowVec), not all window sizes in the
// repeating pattern are used. Update mMaxAttentionWindow to the actual
// maximum window size that has been allocated in the block manager.
mMaxAttentionWindow = mBlockManager.getLastWindowSize();

TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0,
"[kv cache manager] streamLLM is not supported at the moment");
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
!= maxAttentionWindowVec.end());
// The sink tokens are stored in blocks separate from other tokens.
// If the last block of sink tokens is only partially filled,
// we fill that block with a "bubble" to reach the number of tokens per block.
Expand Down
44 changes: 35 additions & 9 deletions cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,7 +44,8 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids,
torch::optional<torch::Tensor> const& out_tensor = torch::nullopt)
{
TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
Expand Down Expand Up @@ -396,9 +397,34 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(output2_scales_scalar.dim() == 1, "output2_scales_scalar must be 1D.");
TORCH_CHECK(output2_scales_scalar.sizes()[0] == local_num_experts, "output2_scales_scalar has incorrect dim 0.");

// allocate output
at::Tensor output = at::detail::empty_cuda(
{args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt);
// allocate or use provided output
at::Tensor output;
if (out_tensor.has_value())
{
TORCH_CHECK(do_finalize, "out_tensor is only supported when do_finalize=true.");
TORCH_CHECK(out_tensor->scalar_type() == at::ScalarType::BFloat16, "out_tensor must be bfloat16.");
TORCH_CHECK(out_tensor->dim() == 2, "out_tensor must be 2D.");
TORCH_CHECK(out_tensor->sizes()[0] == args.num_tokens, "out_tensor dim0 must match num_tokens.");
TORCH_CHECK(out_tensor->device() == hidden_states.device(), "out_tensor must be on the same device as inputs.");
auto const out_hidden = out_tensor->sizes()[1];
if (out_hidden < args.hidden_size)
{
// out_tensor has unpadded hidden dim (e.g., nvfp4 with padding).
// Set valid_hidden_size so the finalize kernel writes only the needed columns.
args.valid_hidden_size = out_hidden;
args.output_hidden_size = tensorrt_llm::common::roundUp(out_hidden, static_cast<int64_t>(128));
}
else
{
TORCH_CHECK(out_hidden == args.hidden_size, "out_tensor hidden dim must match hidden_size.");
}
output = out_tensor.value();
}
else
{
output = at::detail::empty_cuda(
{args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt);
}

// setup workspace
workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>();
Expand Down Expand Up @@ -508,7 +534,7 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
torch::optional<torch::Tensor> const& topk_ids)
torch::optional<torch::Tensor> const& topk_ids, torch::optional<torch::Tensor> const& output = torch::nullopt)
{
// moeConfigIndex corresponds to pair (tileN, config)
auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
Expand All @@ -533,7 +559,7 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
topk_weights, topk_ids);
topk_weights, topk_ids, output);
}

private:
Expand Down Expand Up @@ -597,7 +623,7 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
torch::optional<torch::Tensor> const& topk_ids)
torch::optional<torch::Tensor> const& topk_ids, torch::optional<torch::Tensor> const& output = torch::nullopt)
{
// moeConfigIndex corresponds to pair (tileN, config)
auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
Expand All @@ -621,7 +647,7 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids, output);
}

private:
Expand Down
27 changes: 20 additions & 7 deletions cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,7 +44,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logit
int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
MoeRunnerType& moe_runner, int64_t moeConfigIndex, std::optional<at::Tensor> const& topk_weights,
std::optional<at::Tensor> const& topk_ids)
std::optional<at::Tensor> const& topk_ids, std::optional<at::Tensor> const& out_tensor = std::nullopt)
{
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP8 block scale MOE");

Expand Down Expand Up @@ -272,9 +272,22 @@ at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logit
TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size / 128, "gemm2_weights_scale has incorrect shape.");
TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / 128, "gemm2_weights_scale has incorrect shape.");

// allocate output
at::Tensor output = at::detail::empty_cuda(
{args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt);
// allocate or use provided output
at::Tensor output;
if (out_tensor.has_value())
{
TORCH_CHECK(out_tensor->scalar_type() == at::ScalarType::BFloat16, "out_tensor must be bfloat16.");
TORCH_CHECK(out_tensor->dim() == 2, "out_tensor must be 2D.");
TORCH_CHECK(out_tensor->sizes()[0] == args.num_tokens && out_tensor->sizes()[1] == args.hidden_size,
"out_tensor has incorrect shape.");
TORCH_CHECK(out_tensor->device() == hidden_states.device(), "out_tensor must be on the same device as inputs.");
output = out_tensor.value();
}
else
{
output = at::detail::empty_cuda(
{args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt);
}

// setup workspace
workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>();
Expand Down Expand Up @@ -361,7 +374,7 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t routing_method_type,
std::vector<int64_t> tile_config_pair, std::optional<at::Tensor> const& topk_weights,
std::optional<at::Tensor> const& topk_ids)
std::optional<at::Tensor> const& topk_ids, std::optional<at::Tensor> const& output = std::nullopt)
{
// tile_config_pair corresponds to pair (tileN, config)
auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]);
Expand All @@ -382,7 +395,7 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
return run_fp8_block_scale_moe(routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group,
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
routing_method_type, *mRunners.at(tileN), config, topk_weights, topk_ids);
routing_method_type, *mRunners.at(tileN), config, topk_weights, topk_ids, output);
}

private:
Expand Down
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -530,7 +530,8 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
std::optional<int64_t> const valid_hidden_size, std::optional<int64_t> const valid_intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, std::optional<double> routed_scaling_factor,
int64_t routing_method_type, std::vector<int64_t> moeConfigIndex,
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids,
torch::optional<torch::Tensor> const& output = torch::nullopt)

{
// moeConfigIndex corresponds to pair (tileN, config)
Expand All @@ -555,8 +556,7 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
gemm2_weights_scale, gemm2_bias, std::nullopt, std::nullopt, std::nullopt, num_experts, top_k, n_group,
topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset,
local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config,
topk_weights, topk_ids,
/*out_tensor=*/torch::nullopt); // TODO: Support user-provided output
topk_weights, topk_ids, output);
}

private:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading