33 commits
9966962
fix: #3137 speculative decoding and multimodal input support (#3276)
maxilevi Apr 9, 2025
943218b
feat: Add Qwen2.5-VL and refactor Qwen2-VL (#3156)
yechank-nvidia Apr 9, 2025
8d164f4
update allowlist (#3428)
tburt-nv Apr 9, 2025
215fb20
chore : split GptExecutor tests out of gpt tests to reduce single tes…
peaceh-nv Apr 10, 2025
9307ff9
fix: Add nested aliases for Llama 4 (#3381)
FrankD412 Apr 10, 2025
b5473f7
waive llama3.1 8B test cases with pipeline parallelism (#3433)
QiJune Apr 10, 2025
c59abae
feat: Add Gemma3 text-only model support (#3247)
brb-nv Apr 10, 2025
fbcf954
[MLA] Deallocate tensors after use (#3286)
hlu1 Apr 10, 2025
16c8f39
feat: Support TLLM_OVERRIDE_LAYER_NUM and TLLM_TRACE_MODEL_FORWARD fo…
yuxianq Apr 10, 2025
b331d62
add sqlite to rocky container (#3114)
tburt-nv Apr 10, 2025
863d023
test: fix memory leak of tests (#3392)
xinhe-nv Apr 10, 2025
cec65bd
clean the waive.txt (#3441)
byshiue Apr 10, 2025
67949f7
Update README and add benchmarking blog for DeepSeek-R1 (#3232)
Kefeng-Duan Apr 10, 2025
5023e0d
infra: Update some test description which is out of date (#3437)
EmmaQiaoCh Apr 10, 2025
10d2d16
Waive L0 test (#3442)
yiqingy0 Apr 10, 2025
3ade937
feat: Run PyExecutor's inference flow to estimate max_num_tokens for …
HuiGao-NV Apr 10, 2025
d7a0bf9
fix: updating ucxx, which appears to avoid occasional segfaults when …
jdebache Apr 10, 2025
c5e803b
chore: code cleanup for error logging and SharedMemory in proxy.py (#…
Superjomn Apr 10, 2025
f5281ff
waive some test cases of test_llm_multi_gpu.py (#3452)
QiJune Apr 10, 2025
af05749
feat: add qwen2 moe to torch flow; fix wrong imported KvCacheConfig i…
wm2012011492 Apr 10, 2025
a6a2ae6
chore: Rename nvsmall to nemotron nas (#3447)
amitz-nv Apr 10, 2025
8300218
feat: support llama4 nope layers; support FP8 checkpoint loading; (#3…
nvzhihanj Apr 10, 2025
d7f45e5
test: disable attention DP tests for single GPU (#3395)
Tabrizian Apr 10, 2025
a8310b0
feat: trtllm-gen fp4 GEMM for pytorch workflow (#3423)
DomBrown Apr 10, 2025
5616c0d
add precommit check to github actions (#3129)
tburt-nv Apr 10, 2025
6cef100
waive a test case of llama 3.1 with torch compile (#3461)
QiJune Apr 11, 2025
1e2a339
waive unittest/_torch/multi_gpu (#3464)
QiJune Apr 11, 2025
5142c78
fix: Beam Search Diversity (#3375)
wili-65535 Apr 11, 2025
16ca457
always trigger multi gpu test to protect modeling_llama.py and modeli…
QiJune Apr 11, 2025
410f563
test: Waive torch compile tests (#3471)
syuoni Apr 11, 2025
7048db6
Local test CI
ZhanruiSunCh Apr 11, 2025
16a02ad
For test
ZhanruiSunCh Apr 11, 2025
0d6a7ca
For test again
ZhanruiSunCh Apr 11, 2025
2 changes: 1 addition & 1 deletion .github/workflows/blossom-ci.yml
@@ -40,7 +40,7 @@ jobs:
startsWith(github.event.comment.body, '/bot skip --comment') ||
startsWith(github.event.comment.body, '/bot reuse-pipeline') ||
startsWith(github.event.comment.body, '/bot kill')) && contains(
fromJson('["byshiue","chuangz0","funatiq","hypdeb","jdemouth-nvidia","joyang-nv","lowsfer","Tabrizian","yweng0828","Shixiaowei02","MartinMarciniszyn","schetlur-nv","dcampora","pcastonguay","Naveassaf","lfr-0531","nekorobov","PerkzZheng","kaiyux","nv-guomingz","LinPoly","thorjohnsen","jiahanc","latency1024","tburt-nv","zeroepoch","chzblych","niukuo","ZhanruiSunCh","EmmaQiaoCh","yiqingy0","achartier","suyoggupta","amukkara","mk-nvidia","QiJune","lucaslie","davidmlw","hlu1","nvzhou","syuoni","NVGaryJi","symphonylyh","hello-11","zongfeijing","Jackch-NV","jinyangyuan-nvidia","LarryXFly","crazydemo","jaedeok-nvidia","wm2012011492","rosenrodt","zhuoyao1012","xinhe-nv","Yuening-wa","Shunkangz","zhengd-nv","yibinl-nvidia","StanleySun639","KingsleyLiu-NV","kxdc","yingcanw","BestJuly","ChristinaZ","bobboli","xueweilnvidia","kunlunl","cherichy","lucifer1004","Autumn1998","litaotju","peaceh-nv","liji-nv","SimengLiu-nv","yuxianq","yechank-nvidia","vallis-neria","DylanChen-NV","Tracin","zhhuang-nv","ISEEKYAN","xupinjie","tongyuantongyu","laikhtewari","zhuolingwang","dominicshanshan","jershi425","shifangx","StudyingShao","Superjomn","dongjiyingdjy","guangyunh-nv","wili-65535","tiffany940107","DanBlanaru","mikeiovine","djns99","ruodil","xiaoweiw-nv","xuwchen","bashimao","yizhang-nv","hyukn","nvpohanh","yuki-666","juney-nvidia","barry-delaney","Kefeng-Duan","MinaHuai","yilin-void","jtchen0528","jmydurant","katec846","CarstyYou","Njuapp","Jie-Fang","nvbrantz","inocsin","ruoqianguo","chenfeiz0326","ming-wei","eopXD","longlee0622","dongfengy","georgeliu95","evezhier","rakib-hasan","shangz-ai","JyChang012","wangsiping1997","yuanjings-nvda","tomeras91","roikoren755","amirkl94","shaharmor98","danielafrimi","amitz-nv","hijkzzz","rzilberstein-nvidia","dc3671","hchings","yuhengxnv","dongxuy04","qiaoxj07","omera-nv","DomBrown","brb-nv","FrankD412","yuhsuan-t","Fridah-nv","a-mccarthy","HuiGao-NV","alexmsettle","meenchen","sugunav14","cjluo-nv","kyleliang-nv","chang-l","WeiHaocheng","qixiang-99","BatshevaBlack","ebarilanM","xmchen1987","lingjiew","heyuhhh","netanel-haber","jiefangz-nv","wyw1267","yunruis","sklevtsov-nvidia","jgangani","pamelap-nvidia","ixlmar"]'),
fromJson('["byshiue","chuangz0","funatiq","hypdeb","jdemouth-nvidia","joyang-nv","lowsfer","Tabrizian","yweng0828","Shixiaowei02","MartinMarciniszyn","schetlur-nv","dcampora","pcastonguay","Naveassaf","lfr-0531","nekorobov","PerkzZheng","kaiyux","nv-guomingz","LinPoly","thorjohnsen","jiahanc","latency1024","tburt-nv","zeroepoch","chzblych","niukuo","ZhanruiSunCh","EmmaQiaoCh","yiqingy0","achartier","suyoggupta","amukkara","mk-nvidia","QiJune","lucaslie","davidmlw","hlu1","nvzhou","syuoni","NVGaryJi","symphonylyh","hello-11","zongfeijing","Jackch-NV","jinyangyuan-nvidia","LarryXFly","crazydemo","jaedeok-nvidia","wm2012011492","rosenrodt","zhuoyao1012","xinhe-nv","Yuening-wa","Shunkangz","zhengd-nv","yibinl-nvidia","StanleySun639","KingsleyLiu-NV","kxdc","yingcanw","BestJuly","ChristinaZ","bobboli","xueweilnvidia","kunlunl","cherichy","lucifer1004","Autumn1998","litaotju","peaceh-nv","liji-nv","SimengLiu-nv","yuxianq","yechank-nvidia","vallis-neria","DylanChen-NV","Tracin","zhhuang-nv","ISEEKYAN","xupinjie","tongyuantongyu","laikhtewari","zhuolingwang","dominicshanshan","jershi425","shifangx","StudyingShao","Superjomn","dongjiyingdjy","guangyunh-nv","wili-65535","tiffany940107","DanBlanaru","mikeiovine","djns99","ruodil","xiaoweiw-nv","xuwchen","bashimao","yizhang-nv","hyukn","nvpohanh","yuki-666","juney-nvidia","barry-delaney","Kefeng-Duan","MinaHuai","yilin-void","jtchen0528","jmydurant","katec846","CarstyYou","Njuapp","Jie-Fang","nvbrantz","inocsin","ruoqianguo","chenfeiz0326","ming-wei","eopXD","longlee0622","dongfengy","georgeliu95","evezhier","rakib-hasan","shangz-ai","JyChang012","wangsiping1997","yuanjings-nvda","tomeras91","roikoren755","amirkl94","shaharmor98","danielafrimi","amitz-nv","hijkzzz","rzilberstein-nvidia","dc3671","hchings","yuhengxnv","dongxuy04","qiaoxj07","omera-nv","DomBrown","brb-nv","FrankD412","yuhsuan-t","Fridah-nv","a-mccarthy","HuiGao-NV","alexmsettle","meenchen","sugunav14","cjluo-nv","kyleliang-nv","chang-l","WeiHaocheng","qixiang-99","BatshevaBlack","ebarilanM","xmchen1987","lingjiew","heyuhhh","netanel-haber","jiefangz-nv","wyw1267","yunruis","sklevtsov-nvidia","jgangani","pamelap-nvidia","ixlmar","GalSha","Dido0o0","rabiel","nvzhihanj","milesial","fzmu727"]'),
github.actor)
steps:
- name: Check if comment is issued by authorized person
22 changes: 22 additions & 0 deletions .github/workflows/l0-test.yml
@@ -16,6 +16,8 @@
# A workflow to trigger ci on hybrid infra (github + self hosted runner)
name: L0-Test
on:
+issue_comment:
+types: [created]
workflow_dispatch:
inputs:
sha:
@@ -28,6 +30,26 @@ on:
description: 'test results url'
required: true
jobs:
+Job-trigger:
+name: Start ci job
+if: |
+startsWith(github.event.comment.body, '/bot run') ||
+startsWith(github.event.comment.body, '/bot skip --comment') ||
+startsWith(github.event.comment.body, '/bot reuse-pipeline') ||
+startsWith(github.event.comment.body, '/bot kill')
+runs-on: [self-hosted, Linux, Jenkins]
+steps:
+- name: Start ci job
+run: |
+CI_SERVER="${{ secrets.CI_SERVER }}"
+JENKINS_URL=$(echo "$CI_SERVER" | cut -d '@' -f 1)
+TOKEN=$(echo "$CI_SERVER" | cut -d '@' -f 2)
+sleep 100
+echo '${{ toJson(github.event) }}' > githubData.json
+curl -s -X POST \
+-H "Content-Type: application/json" \
+-d @githubData.json \
+"$JENKINS_URL/generic-webhook-trigger/invoke?token=$TOKEN"
Upload-Test:
name: Upload test results
runs-on: linux-amd64-cpu4
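The new Job-trigger job splits the `CI_SERVER` secret (of the form `<jenkins-url>@<token>`) and forwards the raw GitHub event JSON to the Jenkins generic-webhook-trigger endpoint. Below is a minimal C++ (libcurl) sketch of that same POST, for illustration only; the function name and error handling are ours, while the workflow itself does this with `curl` in a shell step.

```cpp
// Sketch only: mirrors the workflow's shell step, which posts the GitHub
// event payload to Jenkins' generic-webhook-trigger plugin endpoint.
// jenkinsUrl/token correspond to the two halves of the CI_SERVER secret.
#include <curl/curl.h>
#include <string>

bool postEventToJenkins(std::string const& jenkinsUrl, std::string const& token,
    std::string const& eventJson)
{
    CURL* curl = curl_easy_init();
    if (curl == nullptr)
    {
        return false;
    }

    std::string const url = jenkinsUrl + "/generic-webhook-trigger/invoke?token=" + token;
    curl_slist* headers = curl_slist_append(nullptr, "Content-Type: application/json");

    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, eventJson.c_str());

    CURLcode const rc = curl_easy_perform(curl); // blocking POST

    curl_slist_free_all(headers);
    curl_easy_cleanup(curl);
    return rc == CURLE_OK;
}
```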
43 changes: 43 additions & 0 deletions .github/workflows/precommit-check.yml
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Release Checks
on:
pull_request:
workflow_dispatch:
inputs:
ref:
description: 'commit sha to check'
required: true
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
precommit-check:
name: Pre-commit Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.ref || github.ref }}

- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'

- name: Run pre-commit checks
run: |
python3 -u scripts/release_check.py
2 changes: 1 addition & 1 deletion 3rdparty/ucxx
Submodule ucxx updated 49 files
+2 −0 .github/workflows/build.yaml
+15 −0 .github/workflows/pr.yaml
+1 −1 .github/workflows/trigger-breaking-change-alert.yaml
+2 −1 .pre-commit-config.yaml
+1 −1 README.md
+9 −9 conda/environments/all_cuda-118_arch-x86_64.yaml
+11 −11 conda/environments/all_cuda-128_arch-x86_64.yaml
+7 −7 conda/recipes/ucxx/conda_build_config.yaml
+9 −21 conda/recipes/ucxx/meta.yaml
+2 −3 cpp/CMakeLists.txt
+4 −3 cpp/include/ucxx/buffer.h
+1 −1 cpp/include/ucxx/delayed_submission.h
+3 −3 cpp/include/ucxx/endpoint.h
+7 −7 cpp/include/ucxx/request.h
+3 −0 cpp/include/ucxx/request_am.h
+1 −2 cpp/include/ucxx/request_tag_multi.h
+17 −2 cpp/include/ucxx/typedefs.h
+14 −5 cpp/include/ucxx/worker.h
+1 −0 cpp/python/src/exception.cpp
+1 −0 cpp/python/src/worker.cpp
+1 −0 cpp/src/config.cpp
+2 −0 cpp/src/context.cpp
+4 −4 cpp/src/delayed_submission.cpp
+2 −1 cpp/src/endpoint.cpp
+4 −2 cpp/src/internal/request_am.cpp
+2 −1 cpp/src/listener.cpp
+3 −3 cpp/src/memory_handle.cpp
+16 −4 cpp/src/remote_key.cpp
+15 −13 cpp/src/request_am.cpp
+1 −0 cpp/src/request_data.cpp
+0 −2 cpp/src/request_endpoint_close.cpp
+0 −2 cpp/src/request_flush.cpp
+4 −6 cpp/src/request_mem.cpp
+12 −8 cpp/src/request_stream.cpp
+5 −5 cpp/src/request_tag.cpp
+13 −12 cpp/src/request_tag_multi.cpp
+3 −3 cpp/src/utils/file_descriptor.cpp
+6 −3 cpp/src/utils/sockaddr.cpp
+12 −9 cpp/src/worker.cpp
+5 −0 cpp/tests/buffer.cpp
+1 −0 cpp/tests/context.cpp
+17 −12 cpp/tests/request.cpp
+9 −4 cpp/tests/worker.cpp
+31 −27 dependencies.yaml
+4 −4 python/distributed-ucxx/pyproject.toml
+2 −2 python/libucxx/pyproject.toml
+9 −9 python/ucxx/pyproject.toml
+16 −3 python/ucxx/ucxx/_lib/libucxx.pyx
+9 −3 python/ucxx/ucxx/_lib/ucxx_api.pxd
3 changes: 3 additions & 0 deletions README.md
@@ -18,6 +18,9 @@ TensorRT-LLM
<div align="left">

## Latest News
+* [04/10] TensorRT-LLM DeepSeek R1 performance benchmarking best practices now published.
+✨ [➡️ link](./docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md)

* [04/05] TensorRT-LLM can run Llama 4 at over 40,000 tokens per second on B200 GPUs!

![L4_perf](./docs/source/media/l4_launch_perf.png)
Empty file added a.txt
Empty file.
Empty file added b.txt
Empty file.
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -118,7 +118,7 @@ class CacheTransceiver : public BaseCacheTransceiver
std::unique_ptr<DataRequester> mDataRequester;
std::map<LlmRequest*, std::future<void>> mResponderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
-mpi::MpiComm const *mMpiGroupComm{}, *mMpiWorldComm{};
+mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
mMpiGroupTPInDPComm;
executor::kv_cache::CommState const* mCommState;
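The change above is cosmetic: for a pointer member, `{}` and `{nullptr}` both value-initialize to null, so spelling out `nullptr` only makes the intent explicit. A self-contained sketch demonstrating the equivalence:

```cpp
#include <cassert>

struct Demo
{
    int const* implicitNull{};        // value-initialized: null pointer
    int const* explicitNull{nullptr}; // identical semantics, explicit intent
};

int main()
{
    Demo d;
    assert(d.implicitNull == nullptr);
    assert(d.explicitNull == nullptr);
    return 0;
}
```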
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/CMakeLists.txt
@@ -375,6 +375,7 @@ set(TRTLLM_LINK_LIBS
trtllm_gen_fmha
trtllm_gen_blockscale_gemm
trtllm_gen_fp8_block_scale_moe
+trtllm_gen_gemm
selective_scan_src
ws_layernorm_src
fpA_intB_gemm_src
9 changes: 5 additions & 4 deletions cpp/tensorrt_llm/kernels/beamSearchKernels.cu
@@ -140,20 +140,21 @@ __global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __res
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
{
int const bid = blockIdx.x; // Index of request in batch
-float const diversityRate{diversityRates[batchSlots[bid]]};
+runtime::SizeType32 const slot = batchSlots[bid];
+float const diversityRate{diversityRates[slot]};
T* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;

for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
{
int const iBMIn = i / (nBMOut * 2);
-if (finished[bid * nBMIn + iBMIn].isFinished())
+if (finished[slot * nBMIn + iBMIn].isFinished())
{
-pLocalLogProbs[i] += (i == endIds[bid]) ? 1.0f : 0.0f;
+pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
}
else
{
// nBM is used in VBWS since `cumLogProbs` is initialized with kMaxBeamWidth earlier than BeamSearchLayer
-pLocalLogProbs[i] += cumLogProbs[bid * nBM + iBMIn] + diversityRate * iBMIn;
+pLocalLogProbs[i] += cumLogProbs[slot * nBM + iBMIn] + diversityRate * iBMIn;
}
}
return;
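The beam-search fix above swaps the dense batch index `bid` (`blockIdx.x`) for the persistent batch slot `batchSlots[bid]` when indexing per-request state (`finished`, `endIds`, `cumLogProbs`). A host-side C++ sketch of why the two indices diverge; sizes and values are illustrative only:

```cpp
#include <cassert>
#include <vector>

// Host-side sketch of the indexing bug the fix addresses: per-request state
// lives at a persistent "batch slot", while blockIdx.x enumerates the dense,
// possibly reordered batch. Indexing state with the dense index reads another
// request's data whenever the active slots are not exactly 0..nBS-1.
int main()
{
    int const nMaxSlots = 8;
    std::vector<float> cumLogProbs(nMaxSlots, 0.0f);
    cumLogProbs[5] = -1.5f; // state of the request occupying slot 5

    std::vector<int> const batchSlots{5}; // a single active request, in slot 5
    int const bid = 0;                    // dense index, as blockIdx.x in the kernel

    // Wrong (pre-fix): reads slot 0, which belongs to no/another request.
    float const wrong = cumLogProbs[bid];
    // Right (post-fix): resolve the slot first, as the kernel now does.
    int const slot = batchSlots[bid];
    float const right = cumLogProbs[slot];

    assert(wrong == 0.0f && right == -1.5f);
    return 0;
}
```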
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/kernels/trtllmGenKernels/CMakeLists.txt
@@ -15,6 +15,7 @@
# the License.
#

-add_subdirectory(fmha)
add_subdirectory(blockscaleGemm)
+add_subdirectory(fmha)
add_subdirectory(fp8BlockScaleMoe)
+add_subdirectory(gemm)
@@ -76,6 +76,7 @@ struct TrtllmGenBlockScaleGemmOptions
void TrtllmGenBlockScaleGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, float const* aScale,
void const* b, float const* bScale, void* c, float* cScale, CUstream stream)
{

TrtllmGenBlockScaleGemmOptions options;
options.mM = m;
options.mN = n;
@@ -98,10 +99,9 @@ void TrtllmGenBlockScaleGemmRunner::run(int32_t m, int32_t n, int32_t k, void co
options.mSliceK = mKernelInfo->sliceK;

auto params = TrtllmGenBlockScaleGemmKernelParams::setKernelParams(options, a, aScale, b, bScale, c,
-nullptr /* multimemC */, cScale, nullptr /* ptrPartialSumsForSplitK */,
-nullptr /* multimemPartialSumsForSplitK */, nullptr /* ptrTileBars */, nullptr /* multimemTileBars */,
-nullptr /* ptrCompletionBars */, nullptr /* multimemCompletionBars */, nullptr /* ptrSplitKCompletionBars */, 0,
-1);
+nullptr /* ptrSfc */, nullptr /* multimemC */, cScale /* ptrScaleC */, nullptr /* ptrPartialSumsForSplitK */,
+nullptr /* ptrTileBars */, nullptr /* multimemTileBars */, nullptr /* ptrCompletionBars */,
+nullptr /* multimemCompletionBars */, nullptr /* ptrSplitKCompletionBars */, 0, 1);
TLLM_CHECK_WITH_INFO(sizeof(params) == 832, "Size mismatch between trtllm-gen and trtllm");

CUlaunchConfig launch_config;
28 changes: 28 additions & 0 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/CMakeLists.txt
@@ -0,0 +1,28 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)

filter_cuda_archs("100" SRC_CPP)

add_library(trtllm_gen_gemm OBJECT ${SRC_CPP} ${SRC_CU})

target_compile_definitions(trtllm_gen_gemm PUBLIC TLLM_GEN_EXPORT_INTERFACE)

set_property(TARGET trtllm_gen_gemm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET trtllm_gen_gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
105 changes: 105 additions & 0 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>

#include "KernelRunner.h"
#include "tensorrt_llm/common/assert.h"
#include "trtllmGen_export/GemmInterface.h"

namespace tensorrt_llm
{
namespace kernels
{

TrtllmGenGemmRunner::TrtllmGenGemmRunner(tg::Dtype eltType, tg::Dtype outputType)
: mEltType(eltType)
, mOutputType(outputType)
{
// Select a GEMM kernel config to use
auto const gemm = gemm::GemmInterface();
auto const configs = gemm.getGemmConfigs();

std::vector<int32_t> selectedIndex;

for (size_t i = 0; i < gemm.getNumGemmConfigs(); ++i)
{
auto const options = configs[i].mOptions;

// When we include low-latency kernels we can set transposeMmaOutput via constructor
if (options.mDtypeElt == eltType && options.mDtypeC == outputType && !options.mTransposeMmaOutput)
{
selectedIndex.push_back(i);
}
}

TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given output type");
TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given output type");

mGemmConfig = &configs[selectedIndex[0]];
}

size_t TrtllmGenGemmRunner::getWorkspaceSizeInBytes(
int32_t m, int32_t n, int32_t k, tg::Dtype eltType, tg::Dtype outputType) const
{
gemm::GemmData gemmData;
gemmData.mProblemDimensions.mM = m;
gemmData.mProblemDimensions.mN = n;
gemmData.mProblemDimensions.mK = k;

auto gemm = gemm::GemmInterface();

return gemm.getWorkspaceSizeInBytes(*mGemmConfig, gemmData);
}

void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, float const* aScale, void const* b,
float const* bScale, void* c, float* cScale, void* workspace, CUstream stream, int device)
{
auto gemm = gemm::GemmInterface();

gemm::GemmData gemmData;

// Dims
gemmData.mProblemDimensions.mM = m;
gemmData.mProblemDimensions.mN = n;
gemmData.mProblemDimensions.mK = k;

// Inputs
gemmData.mInputBuffers.mPtrA = a;
gemmData.mInputBuffers.mPtrSfA = aScale;
gemmData.mInputBuffers.mPtrB = b;
gemmData.mInputBuffers.mPtrSfB = bScale;
gemmData.mInputBuffers.mPtrScaleC = cScale;

// Outputs
gemmData.mOutputBuffers.mPtrC = c;

auto isValidConfig = gemm.isValidConfig(*mGemmConfig, gemmData);
TLLM_CHECK_WITH_INFO(isValidConfig, "Invalid GEMM config selected!");

cudaDeviceProp deviceProperties;
cudaGetDeviceProperties(&deviceProperties, device);

// FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
gemm.runInitBeforeWorldSync(*mGemmConfig, gemmData, static_cast<void*>(stream));

auto const err = gemm.run(*mGemmConfig, workspace, gemmData, static_cast<void*>(stream), deviceProperties);

TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
}

} // namespace kernels
} // namespace tensorrt_llm
48 changes: 48 additions & 0 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.h
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda.h>

#include "trtllmGen_export/GemmOptions.h"
#include "trtllmGen_export/trtllm/gen/DtypeDecl.h"

namespace tensorrt_llm
{
namespace kernels
{

namespace tg = trtllm::gen;

class TrtllmGenGemmRunner
{
public:
explicit TrtllmGenGemmRunner(tg::Dtype eltType, tg::Dtype outputType);

[[nodiscard]] size_t getWorkspaceSizeInBytes(
int32_t m, int32_t n, int32_t k, tg::Dtype eltType, tg::Dtype outputType) const;

void run(int32_t m, int32_t n, int32_t k, void const* a, float const* aScale, void const* b, float const* bScale,
void* c, float* cScale, void* workspace, CUstream stream, int device);

private:
tg::Dtype mEltType;
tg::Dtype mOutputType;
gemm::GemmConfig const* mGemmConfig;
};
} // namespace kernels
} // namespace tensorrt_llm
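For orientation, a minimal usage sketch of the runner declared above: construct it for an element/output dtype pair, size and allocate the workspace, then launch. The wrapper function, buffer arguments, and include path are illustrative and not part of the PR; the available `tg::Dtype` enumerators are defined in trtllmGen_export/trtllm/gen/DtypeDecl.h.

```cpp
// Illustrative only: shows the intended call sequence for TrtllmGenGemmRunner.
#include <cstdint>
#include <cuda_runtime.h>

#include "KernelRunner.h" // path within trtllmGenKernels/gemm, as added by this PR

namespace tk = tensorrt_llm::kernels;
namespace tg = trtllm::gen;

// All device pointers are assumed to be preallocated and sized for an MxNxK GEMM.
void runGemmOnce(tg::Dtype eltType, tg::Dtype outputType, int32_t m, int32_t n, int32_t k,
    void const* dA, float const* dAScale, void const* dB, float const* dBScale, void* dC,
    float* dCScale, cudaStream_t stream, int device)
{
    // The constructor asserts that exactly one kernel config matches the dtype pair.
    tk::TrtllmGenGemmRunner runner(eltType, outputType);

    // Size and allocate the kernel workspace before launching.
    size_t const wsBytes = runner.getWorkspaceSizeInBytes(m, n, k, eltType, outputType);
    void* workspace = nullptr;
    cudaMalloc(&workspace, wsBytes);

    runner.run(m, n, k, dA, dAScale, dB, dBScale, dC, dCScale, workspace, stream, device);

    cudaStreamSynchronize(stream);
    cudaFree(workspace);
}
```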