Skip to content

Commit 294d6ff

Browse files
authored
Merge branch 'main' into utils-refactor
2 parents 69975f2 + 53609e5 commit 294d6ff

10 files changed

Lines changed: 206 additions & 228 deletions

File tree

python/sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh

Lines changed: 0 additions & 173 deletions
This file was deleted.

python/sglang/jit_kernel/timestep_embedding.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

python/sglang/multimodal_gen/runtime/layers/visual_embedding.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,10 @@
1919
)
2020

2121
try:
22-
from sglang.jit_kernel.timestep_embedding import (
23-
timestep_embedding as timestep_embedding_cuda,
24-
)
22+
from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda
2523
except Exception as _e:
2624
# Fallback to diffusers implementation so downstream code can still run
27-
# even if `jit_kernel` is not available.
25+
# even if `sgl_kernel` is not installed/available.
2826
timestep_embedding_cuda = _get_timestep_embedding
2927

3028
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
@@ -88,13 +86,14 @@ def forward(self, x):
8886

8987
class Timesteps(_Timesteps):
9088
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
91-
return timestep_embedding_cuda(
89+
t_emb = timestep_embedding_cuda(
9290
timesteps,
9391
self.num_channels,
9492
flip_sin_to_cos=self.flip_sin_to_cos,
9593
downscale_freq_shift=self.downscale_freq_shift,
9694
scale=self.scale,
9795
)
96+
return t_emb
9897

9998

10099
class CombinedTimestepGuidanceTextProjEmbeddings(

sgl-kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -282,6 +282,7 @@ set(SOURCES
282282
"csrc/elementwise/rope.cu"
283283
"csrc/elementwise/pos_enc.cu"
284284
"csrc/elementwise/topk.cu"
285+
"csrc/sgl_diffusion/elementwise/timestep_embedding.cu"
285286
"csrc/expert_specialization/es_fp8_blockwise.cu"
286287
"csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu"
287288
"csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu"

sgl-kernel/csrc/common_extension.cc

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -609,6 +609,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
609609

610610
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
611611
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
612+
613+
/*
614+
* From csrc/sgl_diffusion/elementwise
615+
*/
616+
m.def(
617+
"timestep_embedding(Tensor input,"
618+
"Tensor output,"
619+
"int dim,"
620+
"bool flip_sin_to_cos,"
621+
"float downscale_freq_shift,"
622+
"float scale,"
623+
"int max_period) -> Tensor");
624+
m.impl("timestep_embedding", torch::kCUDA, &timestep_embedding);
612625
}
613626

614627
REGISTER_EXTENSION(common_ops)

0 commit comments

Comments
 (0)