@@ -589,7 +589,7 @@ index 1cdf65b91..4783cd18f 100644
589589 buf_numel_per_page: tl.constexpr,
590590 index_head_dim: tl.constexpr,
591591diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
592- index ca54a931b..258407c71 100644
592+ index ca54a931b..961d5f62a 100644
593593--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
594594+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
595595@@ -4,6 +4,7 @@ import contextlib
@@ -612,7 +612,17 @@ index ca54a931b..258407c71 100644
612612 device=get_global_server_args().device,
613613 )
614614 self.block_size = block_size
615- @@ -982,6 +986,9 @@ class Indexer(MultiPlatformOp):
615+ @@ -244,6 +248,9 @@ class Indexer(MultiPlatformOp):
616+ x = x.to(self.weights_proj.weight.dtype)
617+ weights, _ = self.weights_proj(x)
618+ weights = weights.float()
619+ + if weights.shape[1] < q_scale.shape[1]:
620+ + assert q_scale.shape[1] % weights.shape[1] == 0
621+ + weights = weights.repeat_interleave(q_scale.shape[1] // weights.shape[1], dim=1)
622+ weights = weights * self.n_heads**-0.5
623+ weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
624+ return weights
625+ @@ -982,15 +989,24 @@ class Indexer(MultiPlatformOp):
616626 query, key = self._get_q_k_bf16(
617627 q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
618628 )
@@ -622,7 +632,12 @@ index ca54a931b..258407c71 100644
622632 q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
623633 with torch.cuda.stream(self.alt_stream):
624634 k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
625- @@ -991,6 +998,9 @@ class Indexer(MultiPlatformOp):
635+ current_stream.wait_stream(self.alt_stream)
636+ + if weights.shape[1] < q_scale.shape[1]:
637+ + assert q_scale.shape[1] % weights.shape[1] == 0
638+ + weights = weights.repeat_interleave(q_scale.shape[1] // weights.shape[1], dim=1)
639+ weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
640+ else:
626641 query, key = self._get_q_k_bf16(
627642 q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
628643 )
@@ -1593,7 +1608,7 @@ index f2ffa9909..6e4d1d460 100644
15931608 self,
15941609 obj: InitWeightsSendGroupForRemoteInstanceReqInput,
15951610diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
1596- index 0914a5230..d637041b0 100644
1611+ index 0914a5230..2a5819856 100644
15971612--- a/python/sglang/srt/managers/tokenizer_manager.py
15981613+++ b/python/sglang/srt/managers/tokenizer_manager.py
15991614@@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
@@ -1641,6 +1656,24 @@ index 0914a5230..d637041b0 100644
16411656 trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
16421657 tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
16431658 tokenized_obj = wrap_shm_features(tokenized_obj)
1659+ @@ -1327,7 +1348,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
1660+ async with self.is_pause_cond:
1661+ self.is_pause = True
1662+ if obj.mode != "abort":
1663+ - await self.send_to_scheduler.send_pyobj(obj)
1664+ + self.send_to_scheduler.send_pyobj(obj)
1665+ else:
1666+ # we are using the model_update_lock to check if there is still on-going requests.
1667+ while True:
1668+ @@ -1341,7 +1362,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
1669+ async def continue_generation(self, obj: ContinueGenerationReqInput):
1670+ async with self.is_pause_cond:
1671+ self.is_pause = False
1672+ - await self.send_to_scheduler.send_pyobj(obj)
1673+ + self.send_to_scheduler.send_pyobj(obj)
1674+ self.is_pause_cond.notify_all()
1675+
1676+ async def update_weights_from_disk(
16441677diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
16451678index 86b009df4..16ebd52ae 100644
16461679--- a/python/sglang/srt/managers/tp_worker.py
0 commit comments