Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions cpp/tensorrt_llm/kernels/gptKernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -195,7 +195,8 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
int batchIdx = blockIdx.x;

// Compute the padding offsets.
auto compute_padding_offset = [&](int* smem_offset, int maxSeqLength, int* paddingOffsets)
auto compute_padding_offset
= [&](int* smem_offset, int maxSeqLength, int* paddingOffsets, int paddingOffsetsCapacity, char const* bufName)
{
// Block x dimension is the batch dimension, while threads iterate all tokens in the sequence.
int seqBegin = smem_offset[batchIdx];
Expand All @@ -206,42 +207,73 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
// The number of padded tokens in the previous sequences.
int paddingOffset = batchIdx * maxSeqLength - seqBegin;

// Iterate over the tokens to update the number of padded elements.
// Each write must lie within the buffer allocation (sized to numTokens on the caller side).
// If the caller's seqQLengths sum exceeds numTokens we surface the violation loudly so the
// upstream metadata bug is investigable rather than producing silent garbage offsets.
for (int tokenIdx = threadIdx.x; tokenIdx < seqLength; tokenIdx += blockDim.x)
{
paddingOffsets[seqBegin + tokenIdx] = paddingOffset;
int const idx = seqBegin + tokenIdx;
if (idx >= paddingOffsetsCapacity)
{
if (threadIdx.x == 0)
{
printf(
"[computeSeqAndPaddingOffsets] %s OOB: blockIdx=%d batchSize=%d capacity=%d seqBegin=%d "
"seqEnd=%d -- sum(seqQLengths) exceeds buffer capacity\n",
bufName, blockIdx.x, params.batchSize, paddingOffsetsCapacity, seqBegin, seqEnd);
}
__trap();
}
paddingOffsets[idx] = paddingOffset;
}
};

if (params.paddingOffsets != nullptr)
{
compute_padding_offset(smemSeqQOffsets, params.maxQSeqLength, params.paddingOffsets);
compute_padding_offset(
smemSeqQOffsets, params.maxQSeqLength, params.paddingOffsets, params.numTokens, "paddingOffsets");
}

if (need_encoder_padding_offsets)
{
compute_padding_offset(smemEncoderSeqQOffsets, params.maxEncoderQSeqLength, params.encoderPaddingOffsets);
compute_padding_offset(smemEncoderSeqQOffsets, params.maxEncoderQSeqLength, params.encoderPaddingOffsets,
params.numTokens, "encoderPaddingOffsets");
}

// Compuate tokens Info (batchIdx, tokenIdxInSeq).
// Compute tokens Info (batchIdx, tokenIdxInSeq).
if (params.tokensInfo != nullptr)
{
// The begin of the sequence.
int seqBegin = params.removePadding ? smemSeqQOffsets[batchIdx] : batchIdx * params.maxQSeqLength;
// The end of the sequence.
int seqEnd = params.removePadding ? smemSeqQOffsets[batchIdx + 1] : (batchIdx + 1) * params.maxQSeqLength;
// FIXME(Eagle): the last sequence needs to consider the paddings.
// On the last block, extend the write range to cover any trailing padding slots in the
// tokensInfo allocation so downstream kernels can read the full [0, numTokens) range.
if (batchIdx == (params.batchSize - 1))
{
seqEnd = std::max(params.numTokens, seqEnd);
}
// The length of the sequence.
int seqLength = seqEnd - seqBegin;

// Iterate over the tokens to update the number of padded elements.
// Each write must lie within the tokensInfo allocation (sized to numTokens on the caller
// side). If the caller's seqQLengths sum exceeds numTokens, surface the violation loudly so
// the upstream metadata bug is investigable rather than producing silent garbage entries.
for (int tokenIdx = threadIdx.x; tokenIdx < seqLength; tokenIdx += blockDim.x)
{
params.tokensInfo[seqBegin + tokenIdx] = make_int2(batchIdx, tokenIdx);
int const idx = seqBegin + tokenIdx;
if (idx >= params.numTokens)
{
if (threadIdx.x == 0)
{
printf(
"[computeSeqAndPaddingOffsets] tokensInfo OOB: blockIdx=%d batchSize=%d numTokens=%d "
"seqBegin=%d seqEnd=%d -- sum(seqQLengths) exceeds numTokens\n",
blockIdx.x, params.batchSize, params.numTokens, seqBegin, seqEnd);
}
__trap();
}
params.tokensInfo[idx] = make_int2(batchIdx, tokenIdx);
}
};

Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reu
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[latency] SKIP (https://nvbugs/6012526)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/6018046)
test_fmha.py::test_fmha SKIP (https://nvbugs/6018058)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] SKIP (https://nvbugs/6029882)
accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[bf16-4-attn_dp_off-trtllm] SKIP (https://nvbugs/5919796)
accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[fp8-4-attn_dp_off-trtllm] SKIP (https://nvbugs/6058066)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True-enable_chunked_prefill=False-v2_kv_cache=False] SKIP (https://nvbugs/6027594)
examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_nvfp4 SKIP (https://nvbugs/6050483)
visual_gen/test_visual_gen_benchmark.py::test_online_benchmark[openai-videos] SKIP (https://nvbugs/6050483)
Expand Down
Loading