Description
Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
Thank you for the great work.
When using the following code in a batch decoding scenario, the hidden_states for prefill (the non-cached part, indexed by [0]) in meta_info are not assigned to the corresponding output of the batch; instead they are stacked onto one or a few outputs, seemingly at random. I generate only 1 token at a time, so the cache-missed (prefill) part of each request should also be exactly 1 token.
```python
outputs = llm.generate(
    input_ids=active_tokens,
    sampling_params=sampling_params,
    return_hidden_states=True,
)
for output in outputs:
    hidden_state = output["meta_info"]["hidden_states"][0][-1]
```
With batch size 1 everything works as expected. With batch size 4, the per-request hidden-state counts still sum to 4, but instead of the expected 1-1-1-1 split across the batch they come out as something like 1-3-0-0.
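For reference, this is roughly how I tally those counts (a small helper written just for this report; it only assumes the `outputs` structure from the snippet above):

```python
def prefill_hidden_state_counts(outputs):
    """Count how many prefill hidden states each request in the batch received."""
    return [len(o["meta_info"]["hidden_states"][0]) for o in outputs]

# With max_new_tokens=1 and a warm prefix cache I expect [1, 1, 1, 1] for a
# batch of 4, but I observe something like [1, 3, 0, 0]; the total still
# sums to the batch size.
```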
For context: I want to obtain the last-layer hidden state at each decoding step and modify the current output token with another NN based on it. If there is a better way to do this than hacking the decoding loop outside the model and relying on prefix caching, please let me know.
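To make the goal concrete, here is a rough sketch of the per-step loop I am trying to build; `token_editor` is a placeholder for my own NN and is not part of sglang:

```python
import torch


def decode_step(llm, active_tokens, sampling_params, token_editor):
    """One decoding step: generate one token per request, then let an external
    NN revise that token based on the last-layer hidden state of the step."""
    outputs = llm.generate(
        input_ids=active_tokens,
        sampling_params=sampling_params,
        return_hidden_states=True,
    )
    for tokens, output in zip(active_tokens, outputs):
        # Last-layer hidden state of the newly generated position; this is the
        # value that currently ends up attached to the wrong request in a batch.
        hidden = torch.tensor(output["meta_info"]["hidden_states"][0][-1])
        proposed = output["output_ids"][-1]
        # Placeholder call: my own model decides which token to actually commit.
        tokens.append(token_editor(hidden, proposed))
    return active_tokens
```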
Reproduction
```python
import torch
from transformers import AutoTokenizer

import sglang as sgl


def main():
    prompts = [
        "Hello.",
        "Who is the president of the United States?",
        "What is the capital of France?",
        "What is the future of AI?",
    ]

    # Create an LLM.
    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    llm = sgl.Engine(
        model_path=model_path,
        skip_tokenizer_init=True,
    )
    sampling_params = {
        "temperature": 0.0,
        "max_new_tokens": 1,  # Generate one token at a time
    }

    # Build the chat messages for each prompt
    messages = []
    for prompt in prompts:
        message = [{
            "role": "user",
            "content": prompt
        }]
        messages.append(message)

    # Apply the chat template and tokenize
    prompts_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    prompts_token = [tokenizer.encode(prompt) for prompt in prompts_text]

    # Get EOS token id
    eos_token_id = tokenizer.eos_token_id

    # Initialize index tracking
    active_indices = list(range(len(prompts_text)))
    finished_indices = []

    # Initialize tokens and tracking
    active_tokens = [tokenizer.encode(prompt) for prompt in prompts_text]
    finished_tokens = []

    # Initialize hidden states tracking
    active_hidden_states = [[] for _ in range(len(active_tokens))]
    finished_hidden_states = []

    # Generate tokens one by one until all prompts reach EOS or the max limit
    max_iterations = 128  # Safety limit
    iteration = 0
    while active_tokens and iteration < max_iterations:
        # Generate a single token for all active prompts
        outputs = llm.generate(
            input_ids=active_tokens,
            sampling_params=sampling_params,
            return_hidden_states=True
        )

        # Process each output
        to_remove = []  # Indices to remove from active lists
        for batch_idx in range(len(active_tokens)):
            output = outputs[batch_idx]
            generated_token = output["output_ids"][-1]
            hidden_state = output["meta_info"]["hidden_states"][0][-1] if len(output["meta_info"]["hidden_states"][0]) > 0 else None
            assert len(output["meta_info"]["hidden_states"]) == 1
            if len(output["meta_info"]["hidden_states"][0]) > 0:
                print(f"Prompt {active_indices[batch_idx]} - Iteration {iteration} - Hidden state shape: {len(output['meta_info']['hidden_states'][0])}")
            else:
                print(f"Prompt {active_indices[batch_idx]} - Iteration {iteration} - Hidden state shape: 0")

            # Get prompt index
            prompt_idx = active_indices[batch_idx]

            # Add the new hidden state
            active_hidden_states[batch_idx].append(hidden_state)

            # Append the token to the input for the next iteration
            active_tokens[batch_idx].append(generated_token)

            # Check if EOS token
            if generated_token == eos_token_id:
                print(f"Prompt {prompt_idx} - EOS token generated at iteration {iteration}")
                # Move to finished collections
                finished_indices.append(prompt_idx)
                finished_tokens.append(active_tokens[batch_idx])
                finished_hidden_states.append(active_hidden_states[batch_idx])
                # Mark for removal from active lists
                to_remove.append(batch_idx)
                # Decode the completed sequence
                completed_text = tokenizer.decode(active_tokens[batch_idx])
                print(f"Completed text for prompt {prompt_idx}: {completed_text}")
            else:
                # print(f"Prompt {prompt_idx} - Iteration {iteration}: Generated token ID {generated_token}")
                pass

        # Remove finished items from active lists (in reverse to avoid index shifting)
        for idx in sorted(to_remove, reverse=True):
            del active_tokens[idx]
            del active_hidden_states[idx]
            del active_indices[idx]

        iteration += 1

    # Print final results for all prompts
    prompt_indices = finished_indices + active_indices
    final_outputs = []
    for i, prompt_idx in enumerate(prompt_indices):
        print("===============================")
        print(f"Prompt: {prompts_text[prompt_idx]}")

        # Get the hidden states (either from finished or still active)
        if prompt_idx in finished_indices:
            # Find in finished collections
            finished_idx = finished_indices.index(prompt_idx)
            hidden_states = finished_hidden_states[finished_idx]
            token_ids = finished_tokens[finished_idx]
        else:
            # Find in active collections
            active_idx = active_indices.index(prompt_idx)
            hidden_states = active_hidden_states[active_idx]
            token_ids = active_tokens[active_idx]

        if hidden_states and len(hidden_states) > 0:
            # TODO: for debugging only, remove later; drop None entries
            # (steps where this request received no prefill hidden state)
            new_hidden_states = []
            for hidden_state in hidden_states:
                if hidden_state is not None:
                    new_hidden_states.append(hidden_state)
            hidden_states = new_hidden_states
            hidden_states_tensor = torch.tensor(hidden_states)
            print(f"Hidden states shape for prompt {prompt_idx}: {hidden_states_tensor.shape}")
            final_outputs.append({
                "prompt": prompts_text[prompt_idx],
                "input_ids": prompts_token[prompt_idx],
                "output_ids": token_ids[len(prompts_token[prompt_idx]):],
                "hidden_states": hidden_states_tensor
            })

    for output in final_outputs:
        prompt_text = tokenizer.decode(output["input_ids"])
        output_text = tokenizer.decode(output["output_ids"])
        print(f"Prompt: {prompt_text}")
        print(f"Output: {output_text}")
        print(f"Hidden states shape: {output['hidden_states'].shape}")
        print()

    llm.shutdown()


if __name__ == "__main__":
    main()
```
Environment
Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA A800-SXM4-80GB
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.0
CUDA_HOME: /share/public/user_dir/futianyu/cuda-12.2
NVCC: Cuda compilation tools, release 12.2, V12.2.91
CUDA Driver Version: 535.161.08
PyTorch: 2.5.1+cu124
sglang: 0.4.4.post1
sgl_kernel: 0.0.5
flashinfer: 0.2.3+cu124torch2.5
triton: 3.1.0
transformers: 4.48.3
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.14
fastapi: 0.115.12
hf_transfer: 0.1.9
huggingface_hub: 0.29.3
interegular: 0.3.3
modelscope: 1.24.0
orjson: 3.10.16
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.3.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.68.2
tiktoken: 0.9.0
anthropic: 0.49.0
decord: 0.6.0
NVIDIA Topology:
      GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0  X    NV8  NV8  NV8  NV8  NV8  NV8  NV8  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  SYS  NODE 0-31,64-95   0             N/A
GPU1  NV8  X    NV8  NV8  NV8  NV8  NV8  NV8  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  SYS  NODE 0-31,64-95   0             N/A
GPU2  NV8  NV8  X    NV8  NV8  NV8  NV8  NV8  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  SYS  NODE 0-31,64-95   0             N/A
GPU3  NV8  NV8  NV8  X    NV8  NV8  NV8  NV8  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  SYS  NODE 0-31,64-95   0             N/A
GPU4  NV8  NV8  NV8  NV8  X    NV8  NV8  NV8  SYS  SYS  SYS  SYS  NODE PXB  PXB  NODE NODE SYS  32-63,96-127 1             N/A
GPU5  NV8  NV8  NV8  NV8  NV8  X    NV8  NV8  SYS  SYS  SYS  SYS  NODE PXB  PXB  NODE NODE SYS  32-63,96-127 1             N/A
GPU6  NV8  NV8  NV8  NV8  NV8  NV8  X    NV8  SYS  SYS  SYS  SYS  NODE NODE NODE PXB  PXB  SYS  32-63,96-127 1             N/A
GPU7  NV8  NV8  NV8  NV8  NV8  NV8  NV8  X    SYS  SYS  SYS  SYS  NODE NODE NODE PXB  PXB  SYS  32-63,96-127 1             N/A
NIC0  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  X    PIX  NODE NODE SYS  SYS  SYS  SYS  SYS  NODE
NIC1  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  PIX  X    NODE NODE SYS  SYS  SYS  SYS  SYS  NODE
NIC2  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  NODE NODE X    PIX  SYS  SYS  SYS  SYS  SYS  NODE
NIC3  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  NODE NODE PIX  X    SYS  SYS  SYS  SYS  SYS  NODE
NIC4  SYS  SYS  SYS  SYS  NODE NODE NODE NODE SYS  SYS  SYS  SYS  X    NODE NODE NODE NODE SYS
NIC5  SYS  SYS  SYS  SYS  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  NODE X    PIX  NODE NODE SYS
NIC6  SYS  SYS  SYS  SYS  PXB  PXB  NODE NODE SYS  SYS  SYS  SYS  NODE PIX  X    NODE NODE SYS
NIC7  SYS  SYS  SYS  SYS  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  NODE NODE NODE X    PXB  SYS
NIC8  SYS  SYS  SYS  SYS  NODE NODE PXB  PXB  SYS  SYS  SYS  SYS  NODE NODE NODE PXB  X    SYS
NIC9  NODE NODE NODE NODE SYS  SYS  SYS  SYS  NODE NODE NODE NODE SYS  SYS  SYS  SYS  SYS  X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_6
NIC5: mlx5_7
NIC6: mlx5_8
NIC7: mlx5_9
NIC8: mlx5_10
NIC9: mlx5_bond_0
ulimit soft: 1048576