Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions .github/workflows/model-registry-check.yml

This file was deleted.

1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ HuggingFace Model → LLM API → Executor (PyTorch/AutoDeploy/TensorRT)
| `tensorrt_llm/executor/executor.py` | Execution abstraction (`GenerationExecutor`) |
| `tensorrt_llm/models/automodel.py` | Auto-discovery and model registry |
| `tensorrt_llm/_torch/models/` | PyTorch backend model implementations (distinct from `models/` used by TensorRT backend) |
| `tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md` | MoE architecture, backends, communication, development patterns — **read before modifying MoE code** |
| `CODING_GUIDELINES.md` | C++ and Python coding standards (referenced throughout, must read before contributing) |

## Design Patterns
Expand Down
8 changes: 6 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,14 @@ class GenericLlmRequest

if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnGenerationLogits)
{
// In streaming mode with beam search, intermediate logits are returned before finalization,
// so they cannot be reordered to match the final beam paths. Non-streaming mode handles
// reordering in postProcessRequest() after gatherTree finalization.
TLLM_LOG_WARNING(
"Returning generation logits when streaming is enabled and beamWidth > 1 is not allowed. "
"This is because the logits may appear in irrelevant order when the beams are gathered, "
"since logits are not. Disabling returnGenerationLogits.");
"This is because intermediate logits cannot be reordered to match the final beam paths "
"until finalization. Use non-streaming mode for correct generation logits with beam search. "
"Disabling returnGenerationLogits.");
mReturnGenerationLogits = false;
}

Expand Down
176 changes: 176 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <stdexcept>
Expand Down Expand Up @@ -1955,6 +1956,181 @@ void TrtGptModelInflightBatching::postProcessRequest(
// store the generated tokens into the mTokensGathered buffer
llmReq.setGeneratedTokens(generatedTokens);

if (llmReq.getReturnGenerationLogits() && llmReq.getGenerationLogitsHost()
&& mWorldConfig.isLastPipelineParallelRank())
{
reorderGenerationLogitsForBeamSearch(
llmReq, seqSlot, reqBeamWidth, maxSeqLength, outputIdsHostData, sequenceLengthsHostData);
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtGptModelInflightBatching::reorderGenerationLogitsForBeamSearch(LlmRequest& llmReq, SizeType32 seqSlot,
    SizeType32 reqBeamWidth, SizeType32 maxSeqLength, TokenIdType const* outputIdsHostData,
    SizeType32 const* sequenceLengthsHostData)
{
    // Reorder generation logits to match the gathered (finalized) beam ordering.
    // During generation, logits are stored indexed by beam SLOT position. After beam search
    // finalization (gatherTree), output_ids are reordered by tracing parentIds to reconstruct
    // the correct beam paths. However, generation_logits are NOT reordered by gatherTree.
    // We fix this here by tracing parentIds on the host to build the beam-slot mapping,
    // then reindexing the logits accordingly.
    //
    // Preconditions (established by the caller, postProcessRequest):
    //  - returnGenerationLogits is enabled and a host-side logits buffer exists,
    //  - this rank is the last pipeline-parallel rank,
    //  - outputIdsHostData and sequenceLengthsHostData hold the already-gathered
    //    (finalized) results, laid out with a stride of maxSeqLength per beam.
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const promptLen = llmReq.mPromptLen;

    // Copy parentIds and ids (ungathered step IDs) from GPU to temporary host buffers.
    // parentIds[slot][t] = the parent slot of beam slot `slot` at position t.
    // ids[slot][t] = the token in beam slot `slot` at position t (before gather).
    auto parentIdsDevice = ITensor::at(mDecoderState->getParentIds(), {seqSlot});
    auto idsDevice = mDecoderState->getIds(seqSlot);

    // Pinned host staging buffers shaped to match the device tensors.
    auto parentIdsHost = runtime::BufferManager::pinnedPool(parentIdsDevice->getShape(), nvinfer1::DataType::kINT32);
    auto idsHost = runtime::BufferManager::pinnedPool(idsDevice->getShape(), nvinfer1::DataType::kINT32);

    mCopyBufferManager.copy(*parentIdsDevice, *parentIdsHost);
    mCopyBufferManager.copy(*idsDevice, *idsHost);
    // Blocking sync: the host-side tracing below must not read the staging buffers
    // before both device-to-host copies have completed on the copy stream.
    mCopyBufferManager.getStream().synchronize();

    // NOTE(review): parentIds hold beam-slot indices, not token ids. The cast via
    // TokenIdType works because both are int32 here, but a slot-index type would
    // express the intent better — confirm TokenIdType stays int32 if this changes.
    auto const* parentIdsData = bufferCast<TokenIdType>(*parentIdsHost);
    auto const* idsData = bufferCast<TokenIdType>(*idsHost);

    // For each final beam b, find the beam slot at the last generated step, then
    // trace back through parentIds to build the slot trace for every generation step.
    // slotTrace[beam][genStep] = the beam slot that produced the logits at that step.
    auto const generationLogitsHost = llmReq.getGenerationLogitsHost();
    auto const& logitsShape = generationLogitsHost->getShape();
    // Non-streaming shape: [beamWidth, maxNewTokens, vocabSizePadded]
    TLLM_CHECK_WITH_INFO(logitsShape.d[0] == reqBeamWidth,
        "Generation logits beam dimension (%ld) does not match beam width (%d).", logitsShape.d[0], reqBeamWidth);
    auto const maxNewTokens = logitsShape.d[1];
    auto const vocabSizePadded = logitsShape.d[2];

    std::vector<std::vector<SizeType32>> slotTrace(reqBeamWidth, std::vector<SizeType32>(maxNewTokens, 0));
    // Set when any (beam, step) pair maps to a different slot; if it stays false,
    // the memcpy-based reordering pass below is skipped entirely.
    bool anyReorderNeeded = false;

    for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
    {
        auto const seqLen = sequenceLengthsHostData[beam];
        auto const genLen = seqLen - promptLen;
        // This beam generated no tokens — nothing to trace or reorder for it.
        if (genLen <= 0)
        {
            continue;
        }

        // Find the starting beam slot at the last generated step by matching the
        // backtracked token sequence against the gathered (finalized) output.
        SizeType32 startSlot = -1;
        for (SizeType32 s = 0; s < reqBeamWidth; ++s)
        {
            SizeType32 slot = s;
            bool matches = true;
            // Walk candidate slot s's path backwards via parentIds, comparing each
            // token against the finalized output of `beam`. A full match identifies
            // the slot whose path gatherTree selected as this beam.
            for (SizeType32 t = seqLen - 1; t >= promptLen; --t)
            {
                if (idsData[slot * maxSeqLength + t] != outputIdsHostData[beam * maxSeqLength + t])
                {
                    matches = false;
                    break;
                }
                // Don't step past the first generated position (t == promptLen).
                if (t > promptLen)
                {
                    slot = parentIdsData[slot * maxSeqLength + t];
                }
            }
            if (matches)
            {
                startSlot = s;
                break;
            }
        }

        // Every finalized beam must correspond to some slot path; failure here means
        // the ungathered ids/parentIds are inconsistent with the gathered output.
        TLLM_CHECK_WITH_INFO(startSlot >= 0,
            "Could not determine beam slot mapping for beam %d during generation logits reordering.", beam);

        // Build the slot trace: slotTrace[beam][g] = the pre-reassignment slot whose
        // logits correspond to generation step g of this beam.
        //
        // The model runs BEFORE beam search reassigns beams to slots, so
        // generationLogits[slot][g] was produced by the pre-reassignment slot —
        // i.e. the slot the beam occupied in the *previous* step.
        // parentIds[postSlot][promptLen+g] gives exactly that pre-reassignment slot,
        // so taking the parentIds lookup before storing (rather than after) yields
        // the correct source slot in a single pass.
        SizeType32 slot = startSlot;
        for (SizeType32 t = seqLen - 1; t >= promptLen; --t)
        {
            slot = parentIdsData[slot * maxSeqLength + t];
            slotTrace[beam][t - promptLen] = slot;
        }

        // Check if any reordering is actually needed for this beam
        auto& slotTraceIds = slotTrace[beam];
        anyReorderNeeded |= std::any_of(
            slotTraceIds.begin(), slotTraceIds.begin() + genLen, [beam](SizeType32 s) { return s != beam; });
    }

    // Reorder the generation logits in-place using a per-step temporary buffer.
    if (anyReorderNeeded)
    {
        auto const logitsDataType = generationLogitsHost->getDataType();
        auto const elemSize = runtime::BufferDataType(logitsDataType).getSize();
        // stepSize = bytes of one step's logits for one beam (a full vocab row).
        auto const stepSize = static_cast<size_t>(vocabSizePadded) * elemSize;

        // Temp buffer for one generation step across all beams: [beamWidth, vocabSizePadded]
        auto tempLogits
            = runtime::BufferManager::pinnedPool(ITensor::makeShape({reqBeamWidth, vocabSizePadded}), logitsDataType);

        // Raw byte pointers: the logits dtype is only known at runtime, so all
        // moves below are byte-wise memcpys of whole vocab rows.
        auto* logitsPtr = static_cast<uint8_t*>(generationLogitsHost->data());
        auto* tempPtr = static_cast<uint8_t*>(tempLogits->data());

        // Per-beam generated lengths (clamped at 0) and their maximum, which bounds
        // the step loop below.
        std::vector<SizeType32> genLens(reqBeamWidth);
        SizeType32 maxGenLen = 0;
        for (SizeType32 b = 0; b < reqBeamWidth; ++b)
        {
            genLens[b] = std::max(SizeType32{0}, sequenceLengthsHostData[b] - promptLen);
            maxGenLen = std::max(maxGenLen, genLens[b]);
        }

        for (SizeType32 g = 0; g < maxGenLen; ++g)
        {
            // Check if any beam that generated this step needs reordering
            bool stepNeedsReorder = false;
            for (SizeType32 b = 0; b < reqBeamWidth; ++b)
            {
                if (g < genLens[b] && slotTrace[b][g] != b)
                {
                    stepNeedsReorder = true;
                    break;
                }
            }
            if (!stepNeedsReorder)
            {
                continue;
            }

            // Copy all beams' logits at this step to the temp buffer
            // (snapshot first, since the reorder may permute slots cyclically).
            for (SizeType32 b = 0; b < reqBeamWidth; ++b)
            {
                // logits layout: [beamWidth, maxNewTokens, vocabSizePadded]
                auto const offset = (static_cast<size_t>(b) * maxNewTokens + g) * stepSize;
                std::memcpy(tempPtr + static_cast<size_t>(b) * stepSize, logitsPtr + offset, stepSize);
            }

            // Reorder: logits[b][g] = temp[slotTrace[b][g]]
            for (SizeType32 b = 0; b < reqBeamWidth; ++b)
            {
                // Beams that finished before step g keep whatever is in their row.
                if (g >= genLens[b])
                {
                    continue;
                }
                auto const dstOffset = (static_cast<size_t>(b) * maxNewTokens + g) * stepSize;
                auto const srcSlot = slotTrace[b][g];
                std::memcpy(logitsPtr + dstOffset, tempPtr + static_cast<size_t>(srcSlot) * stepSize, stepSize);
            }
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

Expand Down
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ class TrtGptModelInflightBatching : public TrtGptModel
/// and overwrites the llmRequest tokens buffer.
/// Called either on request finishing, or at every step when doing beam search and streaming.
void postProcessRequest(LlmRequest& llmReq, std::vector<SizeType32> const& numDroppedTokens);
/// @brief Reorders generation logits to match finalized beam paths after gatherTree.
/// During beam search, logits are stored by beam slot. After finalization, output_ids are
/// reordered by parentIds, but logits are not. This method traces parentIds on the host
/// to build the slot mapping and reindexes the logits accordingly.
void reorderGenerationLogitsForBeamSearch(LlmRequest& llmReq, SizeType32 seqSlot, SizeType32 reqBeamWidth,
SizeType32 maxSeqLength, TokenIdType const* outputIdsHostData, SizeType32 const* sequenceLengthsHostData);
/// @brief Calls gatherTree (via finalize) and transmits the received data across ranks if PP>1
void getDecoderSlotHostOutputs(
SizeType32 seqSlot, bool returnLogProbs, runtime::SamplingConfig const& samplingConfig, bool streaming);
Expand Down
Loading
Loading