[Bug] Hidden states are not correctly acquired for batched processing #4997

Open
@fuvty

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Thank you for the great work.
When using the following code in a batch decoding scenario, the prefill hidden_states in meta_info (the non-cached part, indexed by [0]) are not assigned to their corresponding outputs in the batch; instead they are stacked onto one or a few outputs, seemingly at random. I generate only 1 token at a time, so the cache-missed input for each request should also have length 1.

outputs = llm.generate(
    input_ids=active_tokens,
    sampling_params=sampling_params,
    return_hidden_states=True
)
for output in outputs:
    hidden_state = output["meta_info"]["hidden_states"][0][-1]

With batch size 1 this works fine. With batch size 4, the hidden-state counts still add up to 4, but they are not distributed 1-1-1-1 across the outputs as expected; instead I get something like 1-3-0-0, as in the illustrative check below.
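
A minimal check against the returned metadata makes the mismatch visible (the expected/observed counts are from my runs of the reproduction script below; the snippet itself is only illustrative):

outputs = llm.generate(
    input_ids=active_tokens,          # batch of 4 prompts, one new token each
    sampling_params={"temperature": 0.0, "max_new_tokens": 1},
    return_hidden_states=True,
)
counts = [len(out["meta_info"]["hidden_states"][0]) for out in outputs]
print(counts)  # expected: [1, 1, 1, 1]; observed: something like [1, 3, 0, 0]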

To provide better context: my goal is to obtain the last-layer hidden states at each decoding step and to modify the current output token based on them with another NN. If there is a better way to do this than driving the decoding for-loop outside the model and relying on prefix caching, please let me know. The sketch below shows the per-step usage I am aiming for.
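
For reference, this is only a sketch, not working code; token_modifier is a placeholder for my own NN (here just an identity function), not an sglang API. Everything else follows the reproduction script below.

token_modifier = lambda hidden, token: token  # placeholder: would be a small NN in practice

for step in range(128):  # safety limit, mirrors max_iterations in the repro script
    outputs = llm.generate(
        input_ids=active_tokens,
        sampling_params={"temperature": 0.0, "max_new_tokens": 1},
        return_hidden_states=True,
    )
    for i, output in enumerate(outputs):
        last_hidden = output["meta_info"]["hidden_states"][0][-1]  # last-layer hidden state of the new token
        new_token = token_modifier(last_hidden, output["output_ids"][-1])  # adjust the sampled token
        active_tokens[i].append(new_token)  # feed back for the next step, relying on prefix caching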

Reproduction

import torch
from transformers import AutoTokenizer
import sglang as sgl


def main():
    prompts = [
        "Hello.",
        "Who is the president of the United States?",
        "What is the capital of France?",
        "What is the future of AI?",
    ]
    # Create an LLM.
    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    llm = sgl.Engine(
        model_path=model_path,
        skip_tokenizer_init=True,
    )

    sampling_params = {
        "temperature": 0.0,
        "max_new_tokens": 1,  # Generate one token at a time
    }

    # Build a chat message list for each prompt
    messages = []
    for prompt in prompts:
        message = [{
            "role": "user",
            "content": prompt
        }]
        messages.append(message)

    # Apply the chat template and tokenize the prompts
    prompts_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    prompts_token = [tokenizer.encode(prompt) for prompt in prompts_text]

    # Get EOS token id
    eos_token_id = tokenizer.eos_token_id
    
    # Initialize index tracking
    active_indices = list(range(len(prompts_text)))
    finished_indices = []

    # Initialize tokens and tracking
    active_tokens = [tokenizer.encode(prompt) for prompt in prompts_text]
    finished_tokens = []

    # Initialize hidden states tracking
    active_hidden_states = [[] for _ in range(len(active_tokens))]
    finished_hidden_states = []
    
    # Generate tokens one by one until all prompts reach EOS or max limit
    max_iterations = 128  # Safety limit
    iteration = 0
    
    while active_tokens and iteration < max_iterations:
        # Generate a single token for all active prompts
        outputs = llm.generate(
            input_ids=active_tokens,
            sampling_params=sampling_params,
            return_hidden_states=True
        )
        
        # Process each output
        to_remove = []  # Indices to remove from active lists
        
        for batch_idx in range(len(active_tokens)):
            output = outputs[batch_idx]
            generated_token = output["output_ids"][-1]
            hidden_state = output["meta_info"]["hidden_states"][0][-1] if len(output["meta_info"]["hidden_states"][0]) > 0 else None

            assert len(output["meta_info"]["hidden_states"]) == 1
            if len(output["meta_info"]["hidden_states"][0]) > 0:
                print(f"Prompt {active_indices[batch_idx]} - Iteration {iteration} - Hidden state shape: {len(output['meta_info']['hidden_states'][0])}")
            else:
                print(f"Prompt {active_indices[batch_idx]} - Iteration {iteration} - Hidden state shape: 0")

            # Get prompt index
            prompt_idx = active_indices[batch_idx]
            
            # Add the new hidden state
            active_hidden_states[batch_idx].append(hidden_state)

            # Append the token to input for next iteration
            active_tokens[batch_idx].append(generated_token)
            
            # Check if EOS token
            if generated_token == eos_token_id:
                print(f"Prompt {prompt_idx} - EOS token generated at iteration {iteration}")
                
                # Move to finished collections
                finished_indices.append(prompt_idx)
                finished_tokens.append(active_tokens[batch_idx])
                finished_hidden_states.append(active_hidden_states[batch_idx])
                
                # Mark for removal from active lists
                to_remove.append(batch_idx)
                
                # Decode the completed sequence
                completed_text = tokenizer.decode(active_tokens[batch_idx])
                print(f"Completed text for prompt {prompt_idx}: {completed_text}")
            else:
                # print(f"Prompt {prompt_idx} - Iteration {iteration}: Generated token ID {generated_token}")
                pass
        
        # Remove finished items from active lists (in reverse to avoid index shifting)
        for idx in sorted(to_remove, reverse=True):
            del active_tokens[idx]
            del active_hidden_states[idx]
            del active_indices[idx]
            
        iteration += 1
    
    # Print final results for all prompts
    prompt_indices = finished_indices + active_indices
    final_outputs = []
    for i, prompt_idx in enumerate(prompt_indices):
        print("===============================")
        print(f"Prompt: {prompts_text[prompt_idx]}")
        
        # Get the hidden states (either from finished or still active)
        if prompt_idx in finished_indices:
            # Find in finished collections
            finished_idx = finished_indices.index(prompt_idx)
            hidden_states = finished_hidden_states[finished_idx]
            token_ids = finished_tokens[finished_idx]
        else:
            # Find in active collections
            active_idx = active_indices.index(prompt_idx)
            hidden_states = active_hidden_states[active_idx]
            token_ids = active_tokens[active_idx]

        if hidden_states and len(hidden_states) > 0:
            # TODO: debug purpose, remove later
            new_hidden_states = []
            for hidden_state in hidden_states:
                if hidden_state is not None:
                    new_hidden_states.append(hidden_state)
            hidden_states = new_hidden_states
            hidden_states_tensor = torch.tensor(hidden_states)
            print(f"Hidden states shape for prompt {prompt_idx}: {hidden_states_tensor.shape}")
        else:
            # Avoid a NameError below when no hidden states were collected for this prompt
            hidden_states_tensor = torch.empty(0)

        final_outputs.append({
            "prompt": prompts_text[prompt_idx],
            "input_ids": prompts_token[prompt_idx],
            "output_ids": token_ids[len(prompts_token[prompt_idx]):],
            "hidden_states": hidden_states_tensor
        })

    for output in final_outputs:
        prompt_text = tokenizer.decode(output["input_ids"])
        output_text = tokenizer.decode(output["output_ids"])
        print(f"Prompt: {prompt_text}")
        print(f"Output: {output_text}")
        print(f"Hidden states shape: {output['hidden_states'].shape}")
        print()

    llm.shutdown()

if __name__ == "__main__":
    main()

Environment

Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA A800-SXM4-80GB
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.0
CUDA_HOME: /share/public/user_dir/futianyu/cuda-12.2
NVCC: Cuda compilation tools, release 12.2, V12.2.91
CUDA Driver Version: 535.161.08
PyTorch: 2.5.1+cu124
sglang: 0.4.4.post1
sgl_kernel: 0.0.5
flashinfer: 0.2.3+cu124torch2.5
triton: 3.1.0
transformers: 4.48.3
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.14
fastapi: 0.115.12
hf_transfer: 0.1.9
huggingface_hub: 0.29.3
interegular: 0.3.3
modelscope: 1.24.0
orjson: 3.10.16
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.3.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.68.2
tiktoken: 0.9.0
anthropic: 0.49.0
decord: 0.6.0
NVIDIA Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1  NIC2  NIC3  NIC4  NIC5  NIC6  NIC7  NIC8  NIC9  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0  X     NV8   NV8   NV8   NV8   NV8   NV8   NV8   PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   SYS   NODE  0-31,64-95    0              N/A
GPU1  NV8   X     NV8   NV8   NV8   NV8   NV8   NV8   PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   SYS   NODE  0-31,64-95    0              N/A
GPU2  NV8   NV8   X     NV8   NV8   NV8   NV8   NV8   NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   SYS   NODE  0-31,64-95    0              N/A
GPU3  NV8   NV8   NV8   X     NV8   NV8   NV8   NV8   NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   SYS   NODE  0-31,64-95    0              N/A
GPU4  NV8   NV8   NV8   NV8   X     NV8   NV8   NV8   SYS   SYS   SYS   SYS   NODE  PXB   PXB   NODE  NODE  SYS   32-63,96-127  1              N/A
GPU5  NV8   NV8   NV8   NV8   NV8   X     NV8   NV8   SYS   SYS   SYS   SYS   NODE  PXB   PXB   NODE  NODE  SYS   32-63,96-127  1              N/A
GPU6  NV8   NV8   NV8   NV8   NV8   NV8   X     NV8   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   PXB   SYS   32-63,96-127  1              N/A
GPU7  NV8   NV8   NV8   NV8   NV8   NV8   NV8   X     SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   PXB   SYS   32-63,96-127  1              N/A
NIC0  PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   X     PIX   NODE  NODE  SYS   SYS   SYS   SYS   SYS   NODE
NIC1  PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   PIX   X     NODE  NODE  SYS   SYS   SYS   SYS   SYS   NODE
NIC2  NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   NODE  NODE  X     PIX   SYS   SYS   SYS   SYS   SYS   NODE
NIC3  NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   NODE  NODE  PIX   X     SYS   SYS   SYS   SYS   SYS   NODE
NIC4  SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   X     NODE  NODE  NODE  NODE  SYS
NIC5  SYS   SYS   SYS   SYS   PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   NODE  X     PIX   NODE  NODE  SYS
NIC6  SYS   SYS   SYS   SYS   PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   NODE  PIX   X     NODE  NODE  SYS
NIC7  SYS   SYS   SYS   SYS   NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   NODE  NODE  NODE  X     PXB   SYS
NIC8  SYS   SYS   SYS   SYS   NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   X     SYS
NIC9  NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   X

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

NIC Legend:

NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_6
NIC5: mlx5_7
NIC6: mlx5_8
NIC7: mlx5_9
NIC8: mlx5_10
NIC9: mlx5_bond_0

ulimit soft: 1048576
