Skip to content

Commit fdcd28a

Browse files
e-martirosianElizaveta MartirosianElizaveta Martirosian
authored
[NPU] Enable consistency checking for diffusion tests (#27283)
Co-authored-by: Elizaveta Martirosian <elizaveta.martirosian@gmail.com> Co-authored-by: Elizaveta Martirosian <you@example.com>
1 parent c6be251 commit fdcd28a

3 files changed

Lines changed: 22 additions & 6 deletions

File tree

python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ def apply_flashinfer_rope_qk_inplace(
128128
cos = cos_sin_cache[positions, :half_size].to(q.dtype)
129129
sin = cos_sin_cache[positions, half_size:].to(q.dtype)
130130

131+
if current_platform.is_npu():
132+
q_flat = q.reshape(bsz * seqlen, q_heads, d)
133+
k_flat = k.reshape(bsz * seqlen, k_heads, d)
134+
q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox)
135+
k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox)
136+
return q_rot.view(bsz, seqlen, q_heads, d), k_rot.view(
137+
bsz, seqlen, k_heads, d
138+
)
139+
131140
def apply_rope_prefix(x: torch.Tensor, num_heads: int) -> torch.Tensor:
132141
x_flat = x.reshape(bsz * seqlen, num_heads, d)
133142
x_rot = x_flat[..., :rope_dim]

python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
extras=EXTRAS_DISABLE_WARMUP,
3636
),
3737
T2I_sampling_params,
38-
run_consistency_check=False,
3938
),
4039
# === Text to Video (T2V) ===
4140
DiffusionTestCase(
@@ -47,7 +46,6 @@
4746
DiffusionSamplingParams(
4847
prompt=T2V_PROMPT,
4948
),
50-
run_consistency_check=False,
5149
),
5250
]
5351

@@ -62,7 +60,6 @@
6260
extras=EXTRAS_DISABLE_WARMUP,
6361
),
6462
T2I_sampling_params,
65-
run_consistency_check=False,
6663
),
6764
DiffusionTestCase(
6865
"qwen_image_t2i_2npu",
@@ -75,7 +72,6 @@
7572
extras=EXTRAS_DISABLE_WARMUP,
7673
),
7774
T2I_sampling_params,
78-
run_consistency_check=False,
7975
),
8076
# === Text to Video (T2V) ===
8177
DiffusionTestCase(
@@ -90,7 +86,6 @@
9086
DiffusionSamplingParams(
9187
prompt=T2V_PROMPT,
9288
),
93-
run_consistency_check=False,
9489
),
9590
]
9691

python/sglang/multimodal_gen/test/test_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import requests
2222
from PIL import Image, ImageDraw, ImageFont
2323

24+
from sglang.multimodal_gen.runtime.platforms import current_platform
2425
from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var
2526
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
2627
from sglang.multimodal_gen.runtime.utils.perf_logger import (
@@ -34,6 +35,10 @@
3435
logger = init_logger(__name__)
3536

3637
SGL_TEST_FILES_CI_DATA_REVISION = "caa56302ccf2d289e4488ed06d952edf5d2314cf"
38+
39+
if current_platform.is_npu():
40+
SGL_TEST_FILES_CI_DATA_REVISION = "670d66a8a290b62c0c3c077b3e9b0f4a4d9a44e7"
41+
3742
SGL_TEST_FILES_CONSISTENCY_GT_ROOT = (
3843
"https://raw.githubusercontent.com/"
3944
f"sgl-project/ci-data/{SGL_TEST_FILES_CI_DATA_REVISION}/"
@@ -51,7 +56,14 @@
5156
SGL_TEST_FILES_SGLANG_CONSISTENCY_GT_BASE_ASCEND = (
5257
f"{SGL_TEST_FILES_CONSISTENCY_GT_ROOT}/sglang_generated/ascend"
5358
)
59+
5460
SGL_TEST_FILES_CONSISTENCY_GT_BASE = SGL_TEST_FILES_SGLANG_CONSISTENCY_GT_BASE
61+
62+
if current_platform.is_npu():
63+
SGL_TEST_FILES_CONSISTENCY_GT_BASE = (
64+
SGL_TEST_FILES_SGLANG_CONSISTENCY_GT_BASE_ASCEND
65+
)
66+
5567
SGL_TEST_FILES_CONSISTENCY_GT_BASES = (
5668
SGL_TEST_FILES_OFFICIAL_CONSISTENCY_GT_BASE,
5769
SGL_TEST_FILES_SGLANG_CONSISTENCY_GT_BASE,
@@ -1018,7 +1030,7 @@ def _find_remote_consistency_gt_files(
10181030
bases = SGL_TEST_FILES_CONSISTENCY_GT_BASES
10191031
else:
10201032
# Avoid accidentally comparing non-comparable CI cases against official GT.
1021-
bases = (SGL_TEST_FILES_SGLANG_CONSISTENCY_GT_BASE,)
1033+
bases = (SGL_TEST_FILES_CONSISTENCY_GT_BASE,)
10221034
for base_url in bases:
10231035
candidates = _remote_consistency_gt_candidates(
10241036
base_url, case_id, num_gpus, is_video, output_format

0 commit comments

Comments
 (0)