Description
I have implemented an inference API with ONNX Runtime GenAI and FastAPI that processes multiple prompts in batches, with the goal of improving throughput. However, batched generation is significantly slower than processing each prompt individually: when I set batch_size back to 1, the API is noticeably faster than with any larger batch size.
Here is my code:
import time
from typing import List, Optional

import onnxruntime_genai as og
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*']
)

# Load the model and tokenizer once at startup
model = og.Model('/home/rad/bags/models/cuda/cuda-int4-awq-block-128')
tokenizer = og.Tokenizer(model)


def model_run(prompts: List[str], search_options):
    # Tokenize the whole batch and run generation in a single call
    input_tokens = tokenizer.encode_batch(prompts)
    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    output_tokens = model.generate(params)
    out = tokenizer.decode_batch(output_tokens)
    return out


def infer(list_prompt_input: List[str], max_length=2000):
    search_options = {
        'max_length': max_length,
        'temperature': 0.0,
        'top_p': 0.95,
        'top_k': 0.95,
        'repetition_penalty': 1.05,
    }
    chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
    # prompt = f'{chat_template.format(input=list_prompt_input)}'
    prompts = [chat_template.format(input=prompt) for prompt in list_prompt_input]
    outputs = model_run(prompts, search_options)
    # Strip the echoed prompt from each decoded output so only the completion is returned
    result = []
    for idx, output in enumerate(outputs):
        result.append(output.split(list_prompt_input[idx])[-1].strip())
    return result


class BatchInferenceRequest(BaseModel):
    prompts: List[str]
    max_length: Optional[int] = 2000
    batch_size: int = 2


@app.post("/llm_infer")
async def llm_infer(request: BatchInferenceRequest):  # batching much slower than without batch
    max_batch_size = request.batch_size
    result = []
    start_time = time.time()
    # Split the incoming prompts into chunks of batch_size and run each chunk as one batch
    for i in range(0, len(request.prompts), max_batch_size):
        batch_prompts = request.prompts[i:i + max_batch_size]
        outputs = infer(batch_prompts, request.max_length)
        result.extend(outputs)
    end_time = time.time()
    execution_time = end_time - start_time
    return {"results": result, "time": execution_time}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5555)
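For reference, this is roughly how I compare batch sizes. It is a minimal client sketch: it assumes the requests library is installed, the prompt strings below are just placeholders for my real (longer) prompts, and the timings come from the "time" field the endpoint already returns.

import requests

URL = "http://0.0.0.0:5555/llm_infer"
# Placeholder prompts; the real workload uses longer inputs
prompts = ["prompt 1", "prompt 2", "prompt 3", "prompt 4"]

for batch_size in (1, 2):
    # Same prompts each time, only batch_size changes
    resp = requests.post(URL, json={
        "prompts": prompts,
        "max_length": 2000,
        "batch_size": batch_size,
    })
    body = resp.json()
    print(f"batch_size={batch_size}: {body['time']:.2f}s for {len(body['results'])} prompts")

With batch_size=1 the reported time is consistently lower than with batch_size=2, which is the opposite of what I expected from batching.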