
Inference with batching is significantly slower than without batching. #714

Open
@Jester6136

Description

I have implemented an inference API using ONNX Runtime and FastAPI that processes multiple prompts in batches, with the goal of improving throughput. However, I've observed that batched inference is significantly slower than processing each prompt individually; when I set batch_size back to 1, the API performs best.

Here is my code:

import time
from typing import Optional, List

import onnxruntime_genai as og
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*']
)
model = og.Model('/home/rad/bags/models/cuda/cuda-int4-awq-block-128')
tokenizer = og.Tokenizer(model)


def model_run(prompts: List[str], search_options):
    # Encode the whole batch at once and run a single batched generate() call.
    input_tokens = tokenizer.encode_batch(prompts)
    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    output_tokens = model.generate(params)
    out = tokenizer.decode_batch(output_tokens)
    return out


def infer(list_prompt_input: List[str], max_length: int = 2000):
    search_options = {
        'max_length': max_length,
        'temperature': 0.0,
        'top_p': 0.95,
        'top_k': 0.95,  # NOTE: top_k is normally an integer token count; 0.95 looks copied from top_p
        'repetition_penalty': 1.05,
    }
    chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
    # prompt = f'{chat_template.format(input=list_prompt_input)}'
    prompts = [chat_template.format(input=prompt) for prompt in list_prompt_input]
    outputs = model_run(prompts, search_options)

    result = []
    for idx, output in enumerate(outputs):
        # Strip the echoed prompt from the decoded output.
        result.append(output.split(list_prompt_input[idx])[-1].strip())
    return result

class BatchInferenceRequest(BaseModel):
    prompts: List[str]
    max_length: Optional[int] = 2000
    batch_size: int = 2

@app.post("/llm_infer")
async def llm_infer(request: BatchInferenceRequest):    # batching much slower than without batch
    max_batch_size = request.batch_size
    result = []
    
    import time
    start_time = time.time()
    for i in range(0, len(request.prompts), max_batch_size):
        batch_prompts = request.prompts[i:i + max_batch_size]
        outputs = infer(batch_prompts, request.max_length)
        result.extend(outputs)
    end_time = time.time()
    execution_time = end_time - start_time 
    return {"results": result,"time":execution_time}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5555)
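
For reference, below is a minimal standalone timing sketch (no FastAPI) to check whether the slowdown comes from the batched generate() call itself or from the web layer. It reuses only the og calls already shown above; the model path and the two example prompts are placeholders based on this setup, not measured results.

import time
from typing import List

import onnxruntime_genai as og

model = og.Model('/home/rad/bags/models/cuda/cuda-int4-awq-block-128')
tokenizer = og.Tokenizer(model)


def time_generate(prompts: List[str], max_length: int = 2000) -> float:
    # One batched generate() call over the given prompts, returning wall-clock seconds.
    params = og.GeneratorParams(model)
    params.set_search_options(max_length=max_length, temperature=0.0)
    params.input_ids = tokenizer.encode_batch(prompts)
    start = time.time()
    model.generate(params)
    return time.time() - start


# Two example prompts in the same chat template as above (placeholders).
prompts = [
    '<|user|>\nWhat is ONNX Runtime? <|end|>\n<|assistant|>',
    '<|user|>\nSummarise FastAPI in one sentence. <|end|>\n<|assistant|>',
]

# batch_size=1: one generate() call per prompt.
single_total = sum(time_generate([p]) for p in prompts)
print(f'batch_size=1 total: {single_total:.2f}s')

# batch_size=2: both prompts in a single generate() call.
batched_total = time_generate(prompts)
print(f'batch_size=2 total: {batched_total:.2f}s')

Only the generate() call is timed here, so tokenization and decoding are excluded from the comparison.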
