4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
@@ -70,8 +70,8 @@ void initBindings(nb::module_& m)
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt,
nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());

m.def(
"get_helix_workspace_size_per_rank",
101 changes: 49 additions & 52 deletions cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
@@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{

void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens)
{
// Validate all tensors are CUDA tensors
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
// k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
// k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
// slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
// k_cache: [num_blocks, block_size, 1, per_token_size] uint8

TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
&& slot_mapping_scale.is_cuda(),
"All tensors must be CUDA tensors");

// Validate tensor dimensions
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");

// Enforce k_cache is 4D tensor
TORCH_CHECK(k_cache.dim() == 4,
"k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]");
TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]");
TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims",
static_cast<int>(k_cache.dim()));

// Validate tensor dtypes
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
// Validate tensor dtypes — reinterpret_cast below assumes specific element sizes
TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size());
TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d",
k_scale.element_size());
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

// Validate tensor shapes are consistent
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
TORCH_CHECK(
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");

// Validate tensors are contiguous (except k_cache which may be non-contiguous)
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
// k_cache can be non-contiguous - we handle this via strides
TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous");
TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous");
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes

int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size

// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);

int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());

tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
cache_stride_1, cache_stride_2, cache_stride_3, stream);
// FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
int32_t const head_dim = static_cast<int32_t>(k_fp8.size(1));
// Scale size in bytes: num_scale_elements * bytes_per_element.
int32_t const scale_size = static_cast<int32_t>(k_scale.size(1)) * static_cast<int32_t>(k_scale.element_size());

int32_t const cache_dim_0 = static_cast<int32_t>(k_cache.size(0));
int32_t const cache_dim_1 = static_cast<int32_t>(k_cache.size(1));
int32_t const cache_dim_2 = static_cast<int32_t>(k_cache.size(2));
int32_t const cache_dim_3 = static_cast<int32_t>(k_cache.size(3));

TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2);
TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim);
TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size);

int64_t const cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t const cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t const cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t const cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device());

// Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
// For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
tk::invokeIndexerKCacheScatter(reinterpret_cast<uint8_t const*>(k_fp8.data_ptr()),
reinterpret_cast<uint8_t const*>(k_scale.data_ptr()), k_cache.data_ptr<uint8_t>(),
slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(), static_cast<int32_t>(num_tokens),
head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
cache_stride_2, cache_stride_3, stream);
}

} // namespace torch_ext
@@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
"indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, "
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
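For orientation, here is a minimal Python-side sketch of the calling contract the revised op enforces. The helper name is illustrative; the dtype/shape constraints come from the TORCH_CHECKs above, and the `>= num_tokens` assertions reflect the shape comments rather than explicit checks in the op.

```python
import torch

def check_indexer_scatter_args(k_fp8, k_scale, k_cache,
                               slot_mapping_fp8, slot_mapping_scale,
                               num_tokens):
    """Mirror of the TORCH_CHECK constraints above (illustrative helper)."""
    # k_fp8: 2D, 1-byte elements (e.g. FP8); head_dim must be 128.
    assert k_fp8.dim() == 2 and k_fp8.element_size() == 1 and k_fp8.size(1) == 128
    # k_scale: 2D, 4-byte elements (e.g. float32); one scale element = 4 bytes per token.
    assert k_scale.dim() == 2 and k_scale.element_size() == 4 and k_scale.size(1) == 1
    # Slot mappings: 1D int64; only the first num_tokens entries are read.
    assert slot_mapping_fp8.dtype == torch.int64 and slot_mapping_fp8.dim() == 1
    assert slot_mapping_scale.dtype == torch.int64 and slot_mapping_scale.dim() == 1
    assert slot_mapping_fp8.numel() >= num_tokens and k_fp8.size(0) >= num_tokens
    # k_cache: [num_blocks, block_size, 1, per_token_size] uint8; may be non-contiguous.
    assert k_cache.dim() == 4 and k_cache.size(2) == 1 and k_cache.dtype == torch.uint8
    # All tensors must live on CUDA; inputs (but not necessarily k_cache) contiguous.
    for t in (k_fp8, k_scale, k_cache, slot_mapping_fp8, slot_mapping_scale):
        assert t.is_cuda
    for t in (k_fp8, k_scale, slot_mapping_fp8, slot_mapping_scale):
        assert t.is_contiguous()

# Usage sketch (requires the trtllm extension to be loaded):
#   check_indexer_scatter_args(k_fp8, k_scale, k_cache, slot_fp8, slot_scale, num_tokens)
#   torch.ops.trtllm.indexer_k_cache_scatter_op(
#       k_fp8, k_scale, k_cache, slot_fp8, slot_scale, num_tokens)
```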
18 changes: 4 additions & 14 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits)
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
int64_t num_contexts, int64_t num_ctx_tokens)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@@ -833,20 +834,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
}
bool const is_gen_only = attn_input_type == AttentionInputType::GenerationOnly;

int32_t num_contexts = 0;
// count context requests
for (int32_t idx = 0; idx < num_seqs; idx++)
{
if (request_types[idx] != RequestType::kCONTEXT)
{
break;
}
++num_contexts;
}
int32_t const num_generations = num_seqs - num_contexts;
int32_t const num_generations = num_seqs - static_cast<int32_t>(num_contexts);
int32_t const num_tokens = qkv_or_q.size(0);
int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - static_cast<int32_t>(num_ctx_tokens);
auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int32_t>();
auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int32_t>();

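After this change the attention op expects `num_contexts` and `num_ctx_tokens` from the caller instead of deriving them from `request_types` and `host_context_lengths`. A minimal Python sketch of the equivalent host-side computation, mirroring the removed C++ loop (treating request type 0 as CONTEXT is an assumption here):

```python
import torch

def count_context_requests(request_types: torch.Tensor,
                           host_context_lengths: torch.Tensor) -> tuple[int, int]:
    """Count leading context requests and their total token count (sketch)."""
    num_contexts = 0
    for idx in range(request_types.numel()):
        # Context requests are assumed to come first; stop at the first non-context one.
        if int(request_types[idx]) != 0:  # 0 == CONTEXT (assumption)
            break
        num_contexts += 1
    num_ctx_tokens = int(host_context_lengths[:num_contexts].sum())
    return num_contexts, num_ctx_tokens

# num_contexts, num_ctx_tokens = count_context_requests(request_types, host_context_lengths)
# ...then pass both integers to the attention op along with the existing arguments.
```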
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/thop/attentionOp.h
@@ -78,7 +78,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata = std::nullopt,
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt);
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt, int64_t num_contexts = 0,
int64_t num_ctx_tokens = 0);

struct KvCachePoolPointers
{
2 changes: 1 addition & 1 deletion examples/visual_gen/quickstart_example.py
@@ -20,7 +20,7 @@ def main():
inputs="A cat sitting on a windowsill",
params=params,
)
MediaStorage.save_video(output.video, "output.avi", frame_rate=params.frame_rate)
MediaStorage.save_video(output.video, "output.avi", frame_rate=16.0)


if __name__ == "__main__":
25 changes: 14 additions & 11 deletions examples/visual_gen/visual_gen_ltx2.py
@@ -288,10 +288,18 @@ def main():

start_time = time.time()

inputs = {
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
inputs = {"prompt": args.prompt}

extra_params = {
"guidance_rescale": args.guidance_rescale,
"stg_scale": args.stg_scale,
"modality_scale": args.modality_scale,
"rescale_scale": args.rescale_scale,
"guidance_skip_step": args.guidance_skip_step,
"enhance_prompt": args.enhance_prompt,
}
if args.stg_blocks is not None:
extra_params["stg_blocks"] = args.stg_blocks

params = VisualGenParams(
height=args.height,
@@ -302,15 +310,10 @@
seed=args.seed,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
guidance_rescale=args.guidance_rescale,
input_reference=args.image,
negative_prompt=args.negative_prompt,
image=args.image,
image_cond_strength=args.image_cond_strength,
stg_scale=args.stg_scale,
stg_blocks=args.stg_blocks,
modality_scale=args.modality_scale,
rescale_scale=args.rescale_scale,
guidance_skip_step=args.guidance_skip_step,
enhance_prompt=args.enhance_prompt,
extra_params=extra_params,
)

output = visual_gen.generate(inputs=inputs, params=params)
Expand Down
20 changes: 12 additions & 8 deletions examples/visual_gen/visual_gen_wan_i2v.py
@@ -224,22 +224,26 @@ def main():

start_time = time.time()

extra_params = {}
if args.last_image_path:
extra_params["last_image"] = args.last_image_path
if args.guidance_scale_2 is not None:
extra_params["guidance_scale_2"] = args.guidance_scale_2
if args.boundary_ratio is not None:
extra_params["boundary_ratio"] = args.boundary_ratio

output = visual_gen.generate(
inputs={
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
},
inputs={"prompt": args.prompt},
params=VisualGenParams(
height=args.height,
width=args.width,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
num_frames=args.num_frames,
input_reference=args.image_path,
last_image=args.last_image_path if args.last_image_path else None,
guidance_scale_2=args.guidance_scale_2,
boundary_ratio=args.boundary_ratio,
negative_prompt=args.negative_prompt,
image=args.image_path,
extra_params=extra_params if extra_params else None,
),
)

15 changes: 9 additions & 6 deletions examples/visual_gen/visual_gen_wan_t2v.py
@@ -230,20 +230,23 @@ def main():

start_time = time.time()

extra_params = {}
if args.guidance_scale_2 is not None:
extra_params["guidance_scale_2"] = args.guidance_scale_2
if args.boundary_ratio is not None:
extra_params["boundary_ratio"] = args.boundary_ratio

output = visual_gen.generate(
inputs={
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
},
inputs={"prompt": args.prompt},
params=VisualGenParams(
height=args.height,
width=args.width,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
num_frames=args.num_frames,
guidance_scale_2=args.guidance_scale_2,
boundary_ratio=args.boundary_ratio,
negative_prompt=args.negative_prompt,
extra_params=extra_params if extra_params else None,
),
)

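The three example scripts above share the same pattern: pipeline-specific knobs move out of the VisualGenParams constructor into an extra_params dict that is only populated when the corresponding CLI flag was actually set. A small helper capturing that pattern (the helper name and key list are illustrative, not part of the API):

```python
def build_extra_params(args, optional_keys):
    """Collect pipeline-specific options that were actually provided (sketch)."""
    extra = {}
    for key in optional_keys:
        value = getattr(args, key, None)
        if value is not None:
            extra[key] = value
    return extra or None  # pass None when nothing pipeline-specific was set

# e.g. for the Wan examples:
#   extra_params = build_extra_params(args, ["guidance_scale_2", "boundary_ratio"])
#   params = VisualGenParams(..., extra_params=extra_params)
```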
9 changes: 6 additions & 3 deletions tensorrt_llm/__init__.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -132,8 +132,9 @@ def _setup_vendored_triton_kernels():
from .python_plugin import PluginBase
from .sampling_params import SamplingParams
from .version import __version__
from .visual_gen import (VisualGen, VisualGenArgs, VisualGenError,
VisualGenParams, VisualGenResult)
from .visual_gen import (ExtraParamSchema, VisualGen, VisualGenArgs,
VisualGenError, VisualGenParams, VisualGenParamsError,
VisualGenResult)

__all__ = [
'AutoConfig',
@@ -182,7 +183,9 @@ def _setup_vendored_triton_kernels():
'TrtLlmArgs',
'SamplingParams',
'VisualGenArgs',
'ExtraParamSchema',
'VisualGenError',
'VisualGenParamsError',
'VisualGenResult',
'DisaggregatedParams',
'KvCacheConfig',
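The package now also exports ExtraParamSchema and VisualGenParamsError alongside the existing visual-gen symbols. A hedged usage sketch; whether malformed extra_params actually raise VisualGenParamsError at construction time is an assumption about the validation behavior, not something this diff shows:

```python
from tensorrt_llm import VisualGenParams, VisualGenParamsError

try:
    params = VisualGenParams(
        height=480,
        width=832,
        extra_params={"guidance_scale_2": 3.0},  # pipeline-specific knob (assumed valid key)
    )
except VisualGenParamsError as exc:
    # Assumed to signal extra params that do not match the pipeline's schema.
    print(f"invalid visual-gen params: {exc}")
```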
32 changes: 8 additions & 24 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1458,34 +1458,18 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None:
return

# [num_blocks, block_size, 1, per_token_size ]
k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
self.layer_idx)

num_tokens = k_fp8.shape[0]
head_dim = k_fp8.shape[1]
scale_size = k_scale.shape[1] * 4 # Convert to bytes (float32 = 4 bytes)

# Convert to bytes: flatten first, then view as uint8, then reshape
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(
num_tokens, head_dim)

# k_scale: for single-element tensors, contiguous() may be no-op
# Fix stride(-1) for byte-level view
k_scale_flat = k_scale.view(-1)
if k_scale_flat.stride(-1) != 1:
k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
size=(k_scale_flat.numel(), ),
stride=(1, ))
k_scale_bytes = k_scale_flat.view(torch.uint8).view(
num_tokens, scale_size)

# Use CUDA kernel to scatter FP8 and scale bytes into cache
flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
k_cache, flat_indices_fp8,
flat_indices_scale)

# The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw
# bytes internally and only reads the first num_tokens entries from
# the slot mapping buffers, avoiding Python-side view/slice overhead.
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache,
metadata.slot_mapping_fp8,
metadata.slot_mapping_scale,
num_tokens)

def sparse_attn_indexer(
self,
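The removed Python-side byte views are safe to drop because the C++ op now reinterprets the FP8 and float32 buffers as raw bytes itself. A small sanity sketch of the byte-level equivalence (requires a PyTorch build with float8 support; shapes are illustrative):

```python
import torch

num_tokens, head_dim = 4, 128
k_fp8 = torch.randn(num_tokens, head_dim).to(torch.float8_e4m3fn)
k_scale = torch.rand(num_tokens, 1, dtype=torch.float32)

# What the old Python code built explicitly...
k_fp8_bytes = k_fp8.view(torch.uint8)      # [num_tokens, 128], 1 byte per FP8 element
k_scale_bytes = k_scale.view(torch.uint8)  # [num_tokens, 4], 4 bytes per float32 scale

# ...is exactly the byte layout the C++ op now reads via reinterpret_cast.
assert k_fp8_bytes.shape == (num_tokens, head_dim)
assert k_scale_bytes.shape == (num_tokens, 4)
```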