Skip to content

Commit bd70add

Browse files
authored
[docker] fix sglang upgrade bug (#1639)
1 parent 0257bd6 commit bd70add

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

docker/patch/latest/sglang.patch

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ index 1cdf65b91..4783cd18f 100644
589589
buf_numel_per_page: tl.constexpr,
590590
index_head_dim: tl.constexpr,
591591
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
592-
index ca54a931b..258407c71 100644
592+
index ca54a931b..961d5f62a 100644
593593
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
594594
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
595595
@@ -4,6 +4,7 @@ import contextlib
@@ -612,7 +612,17 @@ index ca54a931b..258407c71 100644
612612
device=get_global_server_args().device,
613613
)
614614
self.block_size = block_size
615-
@@ -982,6 +986,9 @@ class Indexer(MultiPlatformOp):
615+
@@ -244,6 +248,9 @@ class Indexer(MultiPlatformOp):
616+
x = x.to(self.weights_proj.weight.dtype)
617+
weights, _ = self.weights_proj(x)
618+
weights = weights.float()
619+
+ if weights.shape[1] < q_scale.shape[1]:
620+
+ assert q_scale.shape[1] % weights.shape[1] == 0
621+
+ weights = weights.repeat_interleave(q_scale.shape[1] // weights.shape[1], dim=1)
622+
weights = weights * self.n_heads**-0.5
623+
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
624+
return weights
625+
@@ -982,15 +989,24 @@ class Indexer(MultiPlatformOp):
616626
query, key = self._get_q_k_bf16(
617627
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
618628
)
@@ -622,7 +632,12 @@ index ca54a931b..258407c71 100644
622632
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
623633
with torch.cuda.stream(self.alt_stream):
624634
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
625-
@@ -991,6 +998,9 @@ class Indexer(MultiPlatformOp):
635+
current_stream.wait_stream(self.alt_stream)
636+
+ if weights.shape[1] < q_scale.shape[1]:
637+
+ assert q_scale.shape[1] % weights.shape[1] == 0
638+
+ weights = weights.repeat_interleave(q_scale.shape[1] // weights.shape[1], dim=1)
639+
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
640+
else:
626641
query, key = self._get_q_k_bf16(
627642
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
628643
)
@@ -1593,7 +1608,7 @@ index f2ffa9909..6e4d1d460 100644
15931608
self,
15941609
obj: InitWeightsSendGroupForRemoteInstanceReqInput,
15951610
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
1596-
index 0914a5230..d637041b0 100644
1611+
index 0914a5230..2a5819856 100644
15971612
--- a/python/sglang/srt/managers/tokenizer_manager.py
15981613
+++ b/python/sglang/srt/managers/tokenizer_manager.py
15991614
@@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
@@ -1641,6 +1656,24 @@ index 0914a5230..d637041b0 100644
16411656
trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
16421657
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
16431658
tokenized_obj = wrap_shm_features(tokenized_obj)
1659+
@@ -1327,7 +1348,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
1660+
async with self.is_pause_cond:
1661+
self.is_pause = True
1662+
if obj.mode != "abort":
1663+
- await self.send_to_scheduler.send_pyobj(obj)
1664+
+ self.send_to_scheduler.send_pyobj(obj)
1665+
else:
1666+
# we are using the model_update_lock to check if there is still on-going requests.
1667+
while True:
1668+
@@ -1341,7 +1362,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
1669+
async def continue_generation(self, obj: ContinueGenerationReqInput):
1670+
async with self.is_pause_cond:
1671+
self.is_pause = False
1672+
- await self.send_to_scheduler.send_pyobj(obj)
1673+
+ self.send_to_scheduler.send_pyobj(obj)
1674+
self.is_pause_cond.notify_all()
1675+
1676+
async def update_weights_from_disk(
16441677
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
16451678
index 86b009df4..16ebd52ae 100644
16461679
--- a/python/sglang/srt/managers/tp_worker.py

docker/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
nightly-dev-20260226a
1+
nightly-dev-20260227a

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ datasets
44
httpx[http2]
55
mcp[cli]
66
memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml
7+
numba
78
omegaconf
89
pillow
910
pylatexenc

0 commit comments

Comments
 (0)