Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 34 additions & 32 deletions jenkins/BuildDockerImage.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -706,39 +706,41 @@ pipeline {
}
steps {
script {
container("python3") {
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade pip")
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade requests")
def nspect_commit = "4cb9c0c42d44ebeeba1e40d2c3eb6aab6fb90173"
def override_commit = env."NSPECT_OVERRIDE_${nspect_commit}"
if (override_commit) {
echo "Overriding nspect_commit with value from environment variable \$NSPECT_OVERRIDE_${nspect_commit}: ${override_commit}"
catchError(buildResult: 'FAILURE', stageResult: 'FAILURE') {
container("python3") {
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade pip")
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade requests")
def nspect_commit = "4cb9c0c42d44ebeeba1e40d2c3eb6aab6fb90173"
def override_commit = env."NSPECT_OVERRIDE_${nspect_commit}"
if (override_commit) {
echo "Overriding nspect_commit with value from environment variable \$NSPECT_OVERRIDE_${nspect_commit}: ${override_commit}"
nspect_commit = override_commit
}
withCredentials([string(credentialsId: "TRTLLM_NSPECT_REPO", variable: "NSPECT_REPO")]) {
trtllm_utils.checkoutSource("${NSPECT_REPO}", nspect_commit, "nspect")
}
def nspect_env = params.nspect_env ? params.nspect_env : "prod"
def program_version_name = params.program_version_name ? params.program_version_name : "PostMerge"
def cmd = """./nspect/nspect.py \
--env ${nspect_env} \
--nspect_id ${params.nspect_id} \
--program_version_name '${program_version_name}' \
"""
if (params.register_images) {
cmd += "--register "
}
if (params.osrb_ticket) {
cmd += "--osrb_ticket ${params.osrb_ticket} "
}
if (params.wait_success_seconds) {
cmd += "--check_launch_api "
cmd += "--wait_success ${params.wait_success_seconds} "
}
cmd += "--image "
cmd += imageKeyToTag.values().join(" ")
withCredentials([usernamePassword(credentialsId: "NSPECT_CLIENT-${nspect_env}", usernameVariable: 'NSPECT_CLIENT_ID', passwordVariable: 'NSPECT_CLIENT_SECRET')]) {
trtllm_utils.llmExecStepWithRetry(this, script: cmd, sleepInSecs: 600, numRetries: 6, shortCommondRunTimeMax: 7200)
}
withCredentials([string(credentialsId: "TRTLLM_NSPECT_REPO", variable: "NSPECT_REPO")]) {
trtllm_utils.checkoutSource("${NSPECT_REPO}", nspect_commit, "nspect")
}
def nspect_env = params.nspect_env ? params.nspect_env : "prod"
def program_version_name = params.program_version_name ? params.program_version_name : "PostMerge"
def cmd = """./nspect/nspect.py \
--env ${nspect_env} \
--nspect_id ${params.nspect_id} \
--program_version_name '${program_version_name}' \
"""
if (params.register_images) {
cmd += "--register "
}
if (params.osrb_ticket) {
cmd += "--osrb_ticket ${params.osrb_ticket} "
}
if (params.wait_success_seconds) {
cmd += "--check_launch_api "
cmd += "--wait_success ${params.wait_success_seconds} "
}
cmd += "--image "
cmd += imageKeyToTag.values().join(" ")
withCredentials([usernamePassword(credentialsId: "NSPECT_CLIENT-${nspect_env}", usernameVariable: 'NSPECT_CLIENT_ID', passwordVariable: 'NSPECT_CLIENT_SECRET')]) {
trtllm_utils.llmExecStepWithRetry(this, script: cmd, sleepInSecs: 600, numRetries: 0, shortCommondRunTimeMax: 7200)
}
}
}
}
Expand Down
51 changes: 35 additions & 16 deletions jenkins/L0_MergeRequest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ STAGE_CHOICE_NORMAL = "normal"
STAGE_CHOICE_SKIP = "skip"
STAGE_CHOICE_IGNORE = "ignore"

RELESE_CHECK_CHOICE = env.releaseCheckChoice ? env.releaseCheckChoice : STAGE_CHOICE_NORMAL
RELEASE_CHECK_CHOICE = env.releaseCheckChoice ? env.releaseCheckChoice : STAGE_CHOICE_NORMAL
BUILD_CHECK_CHOICE = env.buildCheckChoice ? env.buildCheckChoice : STAGE_CHOICE_NORMAL
X86_TEST_CHOICE = env.x86TestChoice ? env.x86TestChoice : STAGE_CHOICE_NORMAL
SBSA_TEST_CHOICE = env.SBSATestChoice ? env.SBSATestChoice : STAGE_CHOICE_NORMAL

Expand Down Expand Up @@ -437,10 +438,14 @@ def launchReleaseCheck(pipeline, globalVars)
sh "cd ${LLM_ROOT} && confidentiality-scan \$(find . -type f ${ignoreList.collect { "-not -path \"${it}\"" }.join(' ')}) 2>&1 | tee scan.log"
def lastLine = sh(script: "tail -n 1 ${LLM_ROOT}/scan.log", returnStdout: true).trim()
if (lastLine.toLowerCase().contains("error")) {
error "Guardwords Scan Failed."
error "GUARDWORDS_WARN: Guardwords Scan Failed."
}
} catch (Exception e) {
} catch (InterruptedException e) {
throw e
} catch (Exception e) {
catchError(buildResult: 'SUCCESS', stageResult: 'UNSTABLE') {
error "Release Check failed (warn-only): ${e.getMessage()}"
}
} finally {
trtllm_utils.uploadArtifacts("${LLM_ROOT}/scan.log", "${UPLOAD_PATH}/guardwords-scan-results/")
echo "Guardwords Scan Results: https://urm.nvidia.com/artifactory/${UPLOAD_PATH}/guardwords-scan-results/scan.log"
Expand Down Expand Up @@ -488,7 +493,7 @@ def launchReleaseCheck(pipeline, globalVars)
stageName = "Release-Check"
trtllm_utils.launchKubernetesPod(pipeline, createKubernetesPodConfig(image, "package"), "trt-llm", {
stage("[${stageName}] Run") {
if (RELESE_CHECK_CHOICE == STAGE_CHOICE_SKIP) {
if (RELEASE_CHECK_CHOICE == STAGE_CHOICE_SKIP) {
echo "Release Check job is skipped due to Jenkins configuration"
return
}
Expand All @@ -498,7 +503,7 @@ def launchReleaseCheck(pipeline, globalVars)
} catch (InterruptedException e) {
throw e
} catch (Exception e) {
if (RELESE_CHECK_CHOICE == STAGE_CHOICE_IGNORE) {
if (RELEASE_CHECK_CHOICE == STAGE_CHOICE_IGNORE) {
catchError(
buildResult: 'SUCCESS',
stageResult: 'FAILURE') {
Expand Down Expand Up @@ -1275,19 +1280,33 @@ def launchStages(pipeline, reuseBuild, testFilter, enableFailFast, globalVars)
script {
def testStageName = "[Build-Docker-Images] Remote Run"
stage(testStageName) {
def branch = env.gitlabBranch ? env.gitlabBranch : "main"
if (globalVars[GITHUB_PR_API_URL]) {
branch = "github-pr-" + globalVars[GITHUB_PR_API_URL].split('/').last()
}
try {
def branch = env.gitlabBranch ? env.gitlabBranch : "main"
if (globalVars[GITHUB_PR_API_URL]) {
branch = "github-pr-" + globalVars[GITHUB_PR_API_URL].split('/').last()
}

def additionalParameters = [
'branch': branch,
'action': "push",
'triggerType': env.JOB_NAME ==~ /.*PostMerge.*/ ? "post-merge" : "pre-merge",
'runSanityCheck': env.JOB_NAME ==~ /.*PostMerge.*/ ? true : false,
]
def additionalParameters = [
'branch': branch,
'action': "push",
'triggerType': env.JOB_NAME ==~ /.*PostMerge.*/ ? "post-merge" : "pre-merge",
'runSanityCheck': env.JOB_NAME ==~ /.*PostMerge.*/ ? true : false,
]

launchJob(pipeline, "/LLM/helpers/BuildDockerImages", false, enableFailFast, globalVars, "x86_64", additionalParameters)
launchJob(pipeline, "/LLM/helpers/BuildDockerImages", false, enableFailFast, globalVars, "x86_64", additionalParameters)
} catch (InterruptedException e) {
throw e
} catch (Exception e) {
if (BUILD_CHECK_CHOICE == STAGE_CHOICE_IGNORE) {
catchError(
buildResult: 'SUCCESS',
stageResult: 'FAILURE') {
error "Build-Docker-Images job failed but ignored due to Jenkins configuration"
}
} else {
throw e
}
}
}
}
}
Expand Down
119 changes: 78 additions & 41 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import functools
import os

import torch
from einops import rearrange, repeat
Expand Down Expand Up @@ -42,6 +43,7 @@
from .layernorm_gated import fused_gated_rmsnorm_quant_shape_ok
from .selective_state_update import \
selective_state_update as selective_state_update_native
from .selective_state_update import selective_state_update_mtp_ssm_cache_trtllm
from .ssd_combined import mamba_chunk_scan_combined


Expand Down Expand Up @@ -165,6 +167,9 @@ def __init__(
and self._mamba_ssm_cache_dtype == torch.float16)
self._philox_rounds = config.quant_config.mamba_ssm_philox_rounds

self._use_mtp_custom_op = os.environ.get(
"TRTLLM_MAMBA2_MTP_USE_CUSTOM_OP", "0") == "1"

if self._use_flashinfer:
logger.info_once("Using flashinfer for selective state update",
key="selective_state_update")
Expand Down Expand Up @@ -472,50 +477,82 @@ def convert_dt():
D = repeat(self.D, "h -> h p", p=self.head_dim)
if is_target_verify:
intermediate_ssm_states = layer_cache.intermediate_ssm
# Build kwargs for MTP selective_state_update
mtp_kwargs = dict(
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
out=preallocated_ssm_out_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
disable_state_update=True,
intermediate_states_buffer=intermediate_ssm_states,
cache_steps=draft_token_num,
intermediate_state_indices=intermediate_state_indices,
x_d_mtp = x_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
)
dt_d_mtp = dt_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
)
B_d_mtp = B_d.view(num_decodes, draft_token_num,
self.tp_ngroups, -1)
C_d_mtp = C_d.view(num_decodes, draft_token_num,
self.tp_ngroups, -1)
out_mtp = preallocated_ssm_out_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
)
if self._use_stochastic_rounding:
mtp_kwargs['rand_seed'] = torch.randint(0,
2**62, (1, ),
device=x_d.device,
dtype=torch.int64)
mtp_kwargs['philox_rounds'] = self._philox_rounds

self.selective_state_update_func(
ssm_states,
x_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
dt_d.view(
num_decodes,
if self._use_mtp_custom_op and not self._use_stochastic_rounding:
# Use the TRT-LLM CUDA custom op for MTP SSM cache
# update. This path does not support stochastic
# rounding (rand_seed / philox_rounds).
selective_state_update_mtp_ssm_cache_trtllm(
ssm_states,
x_d_mtp,
dt_d_mtp,
A,
B_d_mtp,
C_d_mtp,
out_mtp,
intermediate_ssm_states,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
A,
B_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
C_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
D,
**mtp_kwargs,
)
D=D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
disable_state_update=True,
intermediate_state_indices=intermediate_state_indices,
)
else:
# Build kwargs for MTP selective_state_update
mtp_kwargs = dict(
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
out=out_mtp,
disable_state_update=True,
intermediate_states_buffer=intermediate_ssm_states,
cache_steps=draft_token_num,
intermediate_state_indices=intermediate_state_indices,
)
if self._use_stochastic_rounding:
mtp_kwargs['rand_seed'] = torch.randint(
0,
2**62, (1, ),
device=x_d.device,
dtype=torch.int64)
mtp_kwargs['philox_rounds'] = self._philox_rounds

self.selective_state_update_func(
ssm_states,
x_d_mtp,
dt_d_mtp,
A,
B_d_mtp,
C_d_mtp,
D,
**mtp_kwargs,
)
else:
# Build kwargs for selective_state_update
ssu_kwargs = dict(
Expand Down
Loading
Loading