Merged
1,518 changes: 1,441 additions & 77 deletions .pre-commit-config.yaml

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions CODING_GUIDELINES.md
@@ -639,6 +639,119 @@ Do not use Protocol when a shared base class or ABC already exists and implement

Note that TypeVars can also be bound to `Protocol`s. Use this feature to specify the expected interface for an argument to a generic function if duck typing is desired.

## Pre-commit Linting (Supplemental Rules)

Python files are split into two groups with separate lint toolchains:

| Group | Files | Formatting | Linting |
|-------|-------|-----------|---------|
| **A (modern)** | ~550 files | ruff format (100-char) | Full ruff rules |
| **B (legacy)** | ~1,350 files (listed in `legacy-files.txt`) | yapf (80-char) + isort + autoflake | Supplemental ruff rules, baseline-gated |

Key terminology used throughout this section:

- **Legacy files** — Python files listed in `legacy-files.txt` that haven't been migrated to the modern ruff toolchain yet.
- **Known violations** — pre-existing lint issues in legacy files, tracked in `ruff-legacy-baseline.json` (the "violation snapshot"). The snapshot records per-file, per-rule violation counts.
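To make the "per-file, per-rule violation counts" concrete, here is an illustrative sketch of what such a snapshot could look like. The actual schema of `ruff-legacy-baseline.json` is not shown in this document, so the shape, paths, and rule codes below are assumptions for explanation only:

```python
# Illustrative only: the real schema of ruff-legacy-baseline.json is not shown
# in this document; this assumed shape maps file path -> {rule code -> count}.
import json

snapshot_text = """
{
  "src/old_module.py": {"E722": 2, "F401": 5},
  "src/other_legacy.py": {"F821": 1}
}
"""

snapshot = json.loads(snapshot_text)

# How many E722 (bare except) hits are tolerated in src/old_module.py
# before the hook would consider a violation "new" and block the commit.
tolerated = snapshot["src/old_module.py"].get("E722", 0)
print(tolerated)
```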

### Group A files

When you modify a Group A file and commit, ruff will:
1. Format the file (100-char line width)
2. Lint the **entire file** with the full rule set
3. Auto-fix what it can; report remaining issues for you to fix

### Group B (legacy) files

When you modify a Group B file and commit, the legacy tools handle formatting
(isort, yapf, autoflake) and the `ruff-legacy` hook applies supplemental lint
rules that the legacy tools don't cover (e.g., bare except, undefined names,
invalid escape sequences).

The hook is **baseline-gated** — it runs in both local pre-commit and CI:
1. Runs `ruff --fix` on staged legacy files, auto-fixing what it can
2. Compares remaining violations against the violation snapshot
3. If your change introduces a **new** violation (a per-file, per-rule count exceeds the snapshot), the commit is blocked
4. Pre-existing known violations are tolerated — they won't block you
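The baseline gate in step 3 amounts to a per-file, per-rule count comparison. The sketch below is an illustrative reconstruction of that idea, not the hook's actual code; the function and variable names are invented:

```python
def find_new_violations(current, baseline):
    """Return (file, rule) pairs whose violation count exceeds the snapshot.

    Both arguments map file path -> {rule code -> count}. Any count above
    the baselined count means the change introduced a new violation, so
    the commit should be blocked; counts at or below baseline pass.
    """
    new = []
    for path, rules in current.items():
        baselined = baseline.get(path, {})
        for rule, count in rules.items():
            if count > baselined.get(rule, 0):
                new.append((path, rule))
    return new

# One pre-existing E722 is baselined; the change adds a second E722
# plus a first F401, so both show up as new.
baseline = {"legacy.py": {"E722": 1}}
current = {"legacy.py": {"E722": 2, "F401": 1}}
blocked = find_new_violations(current, baseline)
print(sorted(blocked))
```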

### Handling a new lint violation

If the `ruff-legacy` hook blocks your commit, fix the violation. The hook only
blocks on *new* violations your code introduced; pre-existing issues are already
accounted for in the snapshot.

### Reducing known violations

For developers who want to clean up existing tech debt in legacy files:

- **Auto-fixable violations**: `ruff check --config ruff-legacy.toml --fix` fixes the easy ones
(unused imports, f-string conversions, comparison order, etc.)
- **Manual violations**: Many violations (bare excepts, undefined names, shadowed imports) require
human judgment. Run `ruff check --config ruff-legacy.toml <file>` to see what remains.
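As one example of the kind of manual fix involved, a bare `except` (ruff rule E722) is typically narrowed to the specific exception being guarded against. The helper below is a hypothetical illustration, not code from this repository:

```python
# Hypothetical example of fixing a bare-except violation (E722) by hand.
# Before (flagged by ruff):
#     try:
#         value = int(raw)
#     except:          # bare except also swallows KeyboardInterrupt etc.
#         value = 0

def parse_or_default(raw: str, default: int = 0) -> int:
    try:
        return int(raw)
    except ValueError:  # after: catch only the failure we expect
        return default

print(parse_or_default("7"), parse_or_default("oops"))
```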

After any batch cleanup, update the violation snapshot:
```bash
python scripts/legacy_utils.py lint-update-violations
```
Commit the fixed files and updated snapshot together.

The hook prints a hint when it detects your change reduced violations below the
snapshot counts, suggesting you run `--update-baseline` to tighten the ratchet.
This is informational — not blocking.

### Graduating a file from Group B to Group A

1. Remove the path from `legacy-files.txt`
2. Regenerate derived configs: `python scripts/legacy_utils.py gen-configs`
3. Update the violation snapshot: `python scripts/legacy_utils.py lint-update-violations`
4. Fix all violations under the main ruff ruleset: `ruff check --fix <file> && ruff format <file>`
5. Commit everything together (regenerated configs + snapshot + formatted file)

### Maintenance

`legacy-files.txt` is the **single source of truth** for which files are legacy. Three derived
configs are auto-generated from it — **never edit them by hand**:

- `ruff-legacy.toml` (the ruff config with `include` list)
- Auto-generated blocks in `pyproject.toml` (`[tool.ruff.format]` exclude list)
- Auto-generated blocks in `.pre-commit-config.yaml` (regex file anchors)

The `verify-legacy-config` pre-commit hook catches stale configs: it regenerates expected content
from `legacy-files.txt` and diffs against the actual files. If they don't match, it fails and
tells you to run `python scripts/legacy_utils.py gen-configs`.
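The verification idea is regenerate-and-diff. A minimal self-contained sketch of that pattern follows; the names and the generated format are invented for illustration and do not reflect the actual hook's code:

```python
def verify_derived(legacy_files, actual_config):
    """Regenerate the expected derived config from the file list and
    compare it to what is on disk; a mismatch means configs are stale
    and need to be regenerated."""
    expected = (
        "include = [\n"
        + "".join(f'    "{p}",\n' for p in sorted(legacy_files))
        + "]\n"
    )
    return expected == actual_config

files = ["b.py", "a.py"]
stale = 'include = [\n    "a.py",\n]\n'                      # missing b.py
fresh = 'include = [\n    "a.py",\n    "b.py",\n]\n'         # up to date
print(verify_derived(files, stale), verify_derived(files, fresh))
```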

`ruff-legacy-baseline.json` (the violation snapshot) is a separate artifact — update it with
`python scripts/legacy_utils.py lint-update-violations`.

**Dependency chain** — after editing `legacy-files.txt`, two downstream artifacts must be updated:
```bash
# After editing legacy-files.txt (manually or via --prune):

# 1. Regenerate derived configs (ruff-legacy.toml, pyproject.toml blocks, pre-commit anchors)
python scripts/legacy_utils.py gen-configs

# 2. Update the violation snapshot (sync with the new file list)
python scripts/legacy_utils.py lint-update-violations
```

**Periodic housekeeping** — when files listed in `legacy-files.txt` are deleted or renamed by
unrelated PRs, their entries become stale. This is not fatal and doesn't break anything — no hook
matches a non-existent file — but the file list and violation snapshot can accumulate noise over
time. The generate command warns about stale entries and suggests running `prune-files`:
```bash
python scripts/legacy_utils.py prune-files # cleans legacy-files.txt
python scripts/legacy_utils.py gen-configs # regenerate configs
python scripts/legacy_utils.py lint-update-violations # sync snapshot
```
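Conceptually, `prune-files` just drops entries whose paths no longer exist. A self-contained sketch of that idea (the real command's behavior may differ; the predicate argument keeps the sketch testable without touching the filesystem):

```python
def prune(entries, exists):
    """Split a file list into surviving and stale entries.

    `exists` is a predicate (e.g. os.path.exists in real use); stale
    entries are the ones whose paths have been deleted or renamed.
    """
    kept = [p for p in entries if exists(p)]
    stale = [p for p in entries if not exists(p)]
    return kept, stale

on_disk = {"src/kept.py"}
kept, stale = prune(["src/kept.py", "src/deleted.py"], on_disk.__contains__)
print(kept, stale)
```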

**Keeping the violation snapshot current** — if your change reduces the number of known lint
violations in any tracked legacy file, it is strongly recommended (though not yet enforced) that
you update the violation snapshot so the ratchet tightens. The pre-commit hook prints a hint when
it detects reductions. To update:
```bash
python scripts/legacy_utils.py lint-update-violations
```
Commit the updated `ruff-legacy-baseline.json` alongside your changes.

## Documentation Guidelines

#### CLI Options in Documentation
8 changes: 8 additions & 0 deletions CONTRIBUTING.md
@@ -49,6 +49,14 @@ mdformat.................................................................Passed

If any files were modified by this hook, you will need to stage and commit them again.

> **Note:** Python files are split into two groups. **Group A** files get full
> ruff formatting and linting. **Group B** (legacy) files get yapf/isort/autoflake
> formatting plus supplemental ruff lint rules via the `ruff-legacy` hook.
> The legacy hook is baseline-gated: pre-existing violations are tolerated, but
> new violations introduced by your change will block the commit.
> See [CODING_GUIDELINES.md](CODING_GUIDELINES.md#pre-commit-linting-supplemental-rules)
> for details on the two-group system and how to graduate files.

In addition, please try to keep pull requests (PRs) as concise as possible:
* Avoid committing commented-out code.
* Wherever possible, each PR should address a single concern. If there are several otherwise-unrelated things that should be fixed to reach a desired endpoint, our recommendation is to open several PRs and indicate the dependencies in the description. The more complex the changes are in a single PR, the more time it will take to review those changes.
2 changes: 2 additions & 0 deletions cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -236,6 +236,8 @@ template<> struct packed_as<half, 2> { using type =
template<> struct packed_as<float, 2> { using type = float2; };
template<> struct packed_as<int8_t, 2> { using type = int16_t; };
template<> struct packed_as<int32_t, 2> { using type = int2; };
template<> struct packed_as<uint, 2> { using type = uint2; };
template<> struct packed_as<uint, 4> { using type = uint4; };
template<> struct packed_as<half2, 1> { using type = half; };
template<> struct packed_as<float2, 1> { using type = float; };
#ifdef ENABLE_BF16
252 changes: 252 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
@@ -0,0 +1,252 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "fusedDiTQKNormRopeKernel.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <cuda_bf16.h>
#include <cuda_runtime.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Per-head QK Norm + RoPE kernel (FLUX, Cosmos3, UniVideo)
//
// Each warp processes one head of one token (Q or K only; V is untouched).
// Supports:
// - Precomputed cos/sin embeddings
// - Dual-stream attention (text vs image norm weights)
// - Interleaved or rotate_half RoPE modes
//
template <int head_dim, bool interleave>
__global__ void fusedDiTQKNormRopeKernel(__nv_bfloat16* qkv, // [num_tokens, total_heads * head_dim]
int const num_heads_q, int const num_heads_k, int const num_heads_v, float const eps,
__nv_bfloat16 const* q_weight, // [head_dim]
__nv_bfloat16 const* k_weight, // [head_dim]
__nv_bfloat16 const* q_add_weight, // [head_dim] or nullptr
__nv_bfloat16 const* k_add_weight, // [head_dim] or nullptr
float const* cos_emb, // [num_tokens, head_dim]
float const* sin_emb, // [num_tokens, head_dim]
int const num_tokens, int const num_txt_tokens,
int const tokens_per_batch) // seq_len per batch element; 0 = flat (no batching)
{
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;

int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

int const total_qk_heads = num_heads_q + num_heads_k;

// Map warp → (token, head type)
int const tokenIdx = globalWarpIdx / total_qk_heads;
int const localHeadIdx = globalWarpIdx % total_qk_heads;

if (tokenIdx >= num_tokens)
{
return;
}

bool const isQ = localHeadIdx < num_heads_q;
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

int const num_heads = num_heads_q + num_heads_k + num_heads_v;

// Each warp (32 threads) processes one head of head_dim elements.
static_assert(
head_dim % (32 * 2) == 0, "head_dim must be divisible by 64 (each warp thread gets an even number of elements)");
constexpr int numElemsPerThread = head_dim / 32;
float elements[numElemsPerThread];
constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
static_assert(elemSizeBytes % 4 == 0, "elemSizeBytes must be a multiple of 4");
constexpr int vecSize = elemSizeBytes / 4;
using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

// Compute offset into packed QKV tensor (use int64_t to avoid overflow
// when num_tokens * num_heads * head_dim > INT_MAX, e.g. WAN I2V 14B)
int64_t offsetWarp;
if (isQ)
{
offsetWarp = static_cast<int64_t>(tokenIdx) * num_heads * head_dim + headIdx * head_dim;
}
else
{
offsetWarp
= static_cast<int64_t>(tokenIdx) * num_heads * head_dim + num_heads_q * head_dim + headIdx * head_dim;
}
int64_t offsetThread = offsetWarp + laneId * numElemsPerThread;

// ---- Step 1: Load elements and compute sum of squares ----
float sumOfSquares = 0.0f;
{
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
for (int i = 0; i < vecSize; i++)
{
float2 vals = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(reinterpret_cast<uint*>(&vec) + i));
sumOfSquares += vals.x * vals.x;
sumOfSquares += vals.y * vals.y;
elements[2 * i] = vals.x;
elements[2 * i + 1] = vals.y;
}
}

// ---- Step 2: RMS normalization with dual-stream weight selection ----
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

// Select norm weight: text tokens use add_weight (if provided), image tokens use primary weight.
// For batched input (B*S flattened), use modulo to get local token index within each batch element.
int const localTokenIdx = (tokens_per_batch > 0) ? (tokenIdx % tokens_per_batch) : tokenIdx;
bool const useAddWeight = (num_txt_tokens > 0) && (localTokenIdx < num_txt_tokens);

__nv_bfloat16 const* weight_ptr;
if (isQ)
{
weight_ptr = (useAddWeight && q_add_weight != nullptr) ? q_add_weight : q_weight;
}
else
{
weight_ptr = (useAddWeight && k_add_weight != nullptr) ? k_add_weight : k_weight;
}

for (int i = 0; i < numElemsPerThread; i++)
{
int dim = laneId * numElemsPerThread + i;
float weight = __bfloat162float(weight_ptr[dim]);
elements[i] *= rms_rcp * weight;
}

// ---- Step 3: Apply RoPE with precomputed cos/sin ----
int64_t const embOffset = static_cast<int64_t>(tokenIdx) * head_dim;

if constexpr (interleave)
{
// Interleaved pairing: (element[2i], element[2i+1])
for (int i = 0; i < numElemsPerThread; i += 2)
{
int dim = laneId * numElemsPerThread + i;
float cos0 = cos_emb[embOffset + dim];
float sin0 = sin_emb[embOffset + dim];
float cos1 = cos_emb[embOffset + dim + 1];
float sin1 = sin_emb[embOffset + dim + 1];

float x = elements[i];
float y = elements[i + 1];

elements[i] = x * cos0 + (-y) * sin0;
elements[i + 1] = y * cos1 + x * sin1;
}
}
else
{
// rotate_half pairing: element[i] pairs with element[i + D/2].
// Each of the 32 lanes owns numElemsPerThread = D/32 consecutive elements,
// so the partner element at offset D/2 lives in the lane that is
// (D/2) / (D/32) = 16 lanes away. XOR with 16 swaps the two halves.
__syncwarp();
constexpr int pairOffset = 16;

float partner[numElemsPerThread];
for (int i = 0; i < numElemsPerThread; i++)
{
partner[i] = __shfl_xor_sync(0xffffffff, elements[i], pairOffset);
// First half (laneId < pairOffset): rotate_half = [-partner, self]
// result[i] = elements[i] * cos - partner[i] * sin
if (laneId < pairOffset)
{
partner[i] = -partner[i];
}
}
__syncwarp();

for (int i = 0; i < numElemsPerThread; i++)
{
int dim = laneId * numElemsPerThread + i;
float cos_val = cos_emb[embOffset + dim];
float sin_val = sin_emb[embOffset + dim];
elements[i] = elements[i] * cos_val + partner[i] * sin_val;
}
}

// ---- Step 4: Store back ----
{
vec_T vec;
for (int i = 0; i < vecSize; i++)
{
__nv_bfloat162 vals = __float22bfloat162_rn(make_float2(elements[2 * i], elements[2 * i + 1]));
reinterpret_cast<__nv_bfloat162&>(*(reinterpret_cast<uint*>(&vec) + i)) = vals;
}
vec_T* outputPtr = reinterpret_cast<vec_T*>(&qkv[offsetThread]);
*outputPtr = vec;
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void launchFusedDiTQKNormRope(void* qkv, int num_tokens, int num_heads_q, int num_heads_k, int num_heads_v,
int head_dim, float eps, void const* q_weight, void const* k_weight, void const* q_add_weight,
void const* k_add_weight, float const* cos_emb, float const* sin_emb, int num_txt_tokens, bool interleave,
int tokens_per_batch, cudaStream_t stream)
{
constexpr int blockSize = 256;

int const warpsPerBlock = blockSize / 32;
int const totalQKHeads = num_heads_q + num_heads_k;
int const totalWarps = num_tokens * totalQKHeads;

int const gridSize = common::divUp(totalWarps, warpsPerBlock);
dim3 gridDim(gridSize);
dim3 blockDim(blockSize);

#define LAUNCH_PER_HEAD_KERNEL(HEAD_DIM, INTERLEAVE) \
fusedDiTQKNormRopeKernel<HEAD_DIM, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps, \
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight), \
reinterpret_cast<__nv_bfloat16 const*>(q_add_weight), reinterpret_cast<__nv_bfloat16 const*>(k_add_weight), \
cos_emb, sin_emb, num_tokens, num_txt_tokens, tokens_per_batch)

if (interleave)
{
switch (head_dim)
{
case 64: LAUNCH_PER_HEAD_KERNEL(64, true); break;
case 128: LAUNCH_PER_HEAD_KERNEL(128, true); break;
case 256: LAUNCH_PER_HEAD_KERNEL(256, true); break;
default: TLLM_THROW("Unsupported head dimension for fusedDiTQKNormRope: %d", head_dim);
}
}
else
{
switch (head_dim)
{
case 64: LAUNCH_PER_HEAD_KERNEL(64, false); break;
case 128: LAUNCH_PER_HEAD_KERNEL(128, false); break;
case 256: LAUNCH_PER_HEAD_KERNEL(256, false); break;
default: TLLM_THROW("Unsupported head dimension for fusedDiTQKNormRope: %d", head_dim);
}
}
#undef LAUNCH_PER_HEAD_KERNEL
}

} // namespace kernels

TRTLLM_NAMESPACE_END