Add multiprocess support #1394

Closed · wants to merge 2 commits
55 changes: 27 additions & 28 deletions shortfin/python/shortfin_apps/llm/benchmark_client.py
@@ -265,10 +265,10 @@ async def run_benchmark(
num_concurrent_requests: int = 64,
token_selection_strategy: str = "multi_greedy",
endpoint: str = "http://localhost:8080",
streaming = False,
multi_hypothesis = False,
best_of_n = 8,
top_p = 0.95
streaming=False,
multi_hypothesis=False,
best_of_n=8,
top_p=0.95,
):
"""Execute the benchmark and return raw data."""
client = LLMClient(base_url=endpoint, stream=streaming)
@@ -281,9 +281,9 @@ }
}

params = {
"max_completion_tokens": output_token_length,
"token_selection_strategy": token_selection_strategy,
"num_beams": 8,
"max_completion_tokens": output_token_length,
"token_selection_strategy": token_selection_strategy,
"num_beams": 8,
}

if multi_hypothesis:
@@ -342,10 +342,10 @@ async def continuous_load_test(
token_selection_strategy: str,
endpoint: str,
duration: int = 60, # Run for 60 seconds by default
streaming = False,
multi_hypothesis = False,
best_of_n = 8,
top_p = 0.95
streaming=False,
multi_hypothesis=False,
best_of_n=8,
top_p=0.95,
) -> Dict[str, Any]:
"""Run a continuous load test with a single client sending requests continuously."""
client = LLMClient(base_url=endpoint, stream=streaming)
@@ -357,16 +357,15 @@ num_requests = 0
num_requests = 0

params = {
"max_completion_tokens": output_token_length,
"token_selection_strategy": token_selection_strategy,
"num_beams": 8,
"max_completion_tokens": output_token_length,
"token_selection_strategy": token_selection_strategy,
"num_beams": 8,
}

if multi_hypothesis:
params["b_of_n"] = best_of_n
params["top_p"] = top_p


while time.perf_counter() < end_time:
try:
request_start = time.perf_counter()
@@ -399,10 +398,10 @@ async def calculate_throughput(
token_selection_strategy: str,
endpoint: str,
duration: int = 60, # Run for 60 seconds by default
streaming = False,
multi_hypothesis = False,
best_of_n = 8,
top_p = 0.95
streaming=False,
multi_hypothesis=False,
best_of_n=8,
top_p=0.95,
):
"""Calculate throughput by running continuous load tests with multiple concurrent clients."""
print(
@@ -422,7 +421,7 @@ async def calculate_throughput(
streaming=streaming,
multi_hypothesis=multi_hypothesis,
best_of_n=best_of_n,
top_p=top_p
top_p=top_p,
)
)

@@ -477,10 +476,10 @@ async def run_all_benchmarks(
endpoint: str = "http://localhost:8080",
num_throughput_runs: int = 20,
results_dir: str = "results",
multi_hypothesis = False,
streaming = False,
best_of_n = 8,
top_p = 0.95
multi_hypothesis=False,
streaming=False,
best_of_n=8,
top_p=0.95,
):
all_results = []
throughput_results = []
@@ -509,7 +508,7 @@ async def run_all_benchmarks(
streaming=streaming,
multi_hypothesis=multi_hypothesis,
best_of_n=best_of_n,
top_p=top_p
top_p=top_p,
)
result = compute_benchmark_results(benchmark_data)
all_results.append(result)
@@ -655,12 +654,12 @@ async def run_all_benchmarks(
)
parser.add_argument(
"--multi-hypothesis",
action='store_true',
action="store_true",
help="Enable multi hypothesis",
)
parser.add_argument(
"--stream",
action='store_true',
action="store_true",
help="Enable streaming",
)
parser.add_argument(
@@ -694,6 +693,6 @@ async def run_all_benchmarks(
streaming=args.stream,
multi_hypothesis=args.multi_hypothesis,
best_of_n=args.best_of_n,
top_p = args.top_p
top_p=args.top_p,
)
)
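
Note: the benchmark_client.py hunks above are formatting-only (black-style keyword defaults and trailing commas); the multi-hypothesis behavior is unchanged. For context, here is a minimal sketch of the request params the client ends up sending when --multi-hypothesis is set, assuming the key names shown in the diff ("b_of_n", "top_p") and illustrative values:

# Illustrative sketch only: mirrors the params construction in the diff above.
output_token_length = 256                  # assumed value for the sketch
token_selection_strategy = "multi_greedy"
best_of_n, top_p = 8, 0.95

params = {
    "max_completion_tokens": output_token_length,
    "token_selection_strategy": token_selection_strategy,
    "num_beams": 8,
}

multi_hypothesis = True                    # set via the --multi-hypothesis flag
if multi_hypothesis:
    params["b_of_n"] = best_of_n           # key name taken verbatim from the diff
    params["top_p"] = top_p
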
118 changes: 71 additions & 47 deletions shortfin/python/shortfin_apps/llm/components/generate.py
@@ -135,7 +135,8 @@ class ClientGenerateBatchProcess(sf.Process):
"decode_batcher",
"gen_req",
"prefill_batcher",
"responder",
# "responder",
"response_handler",
"tokenizer",
"decode_config",
"service",
@@ -145,14 +146,18 @@ def __init__(
self,
service: LlmGenerateService,
gen_req: GenerateReqInput,
responder: FastAPIResponder,
# responder: FastAPIResponder,
response_handler: callable,
# takes a single arg which is bytes/None, or a list of bytes/None
# for streaming
fiber: sf.Fiber | None = None,
):
super().__init__(
fiber=service.fiber_pool.fibers[0].fiber if fiber is None else fiber
)
self.gen_req = gen_req
self.responder = responder
# self.responder = responder
self.response_handler = response_handler
self.tokenizer = service.tokenizer
self.prefill_batcher = service.prefill_batcher
self.decode_batcher = service.decode_batcher
@@ -168,54 +173,70 @@ def __init__(
async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)

try:
streaming = self.gen_req.stream
self.responder.start_response()
if streaming:
self.responder.stream_start()

# Launch all individual generate processes and wait for them to finish.
gen_processes = []
input_ids = self.gen_req.input_ids
is_pretokenized = input_ids is not None

if is_pretokenized:
input_batch = [input_ids] if self.gen_req.is_single else input_ids
else:
input_batch = self.tokenize()
for index, input_tokens in enumerate(input_batch):
decode_config = copy(self.decode_config)
decode_config.update_from_sampling_params(
self.gen_req.sampling_params
if self.gen_req.is_single
else self.gen_req.sampling_params[index]
)
gen_process = GenerateItemProcess(
self,
self.gen_req,
index,
self.gen_req.text
if self.gen_req.is_single
else self.gen_req.text[index],
input_tokens if is_pretokenized else input_tokens.ids,
eos_token_id=self.tokenizer.eos_token_id,
decode_config=decode_config,
)
gen_processes.append(gen_process)
gen_process.launch()

await asyncio.gather(*gen_processes)
self.generate_response(gen_processes, streaming)
finally:
# Try to add request to queue
if not self.service.add_to_queue():
error_response = JSONResponse(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
content={
"error": "Server queue is full. Please try again later.",
"code": "QUEUE_FULL",
"current_size": self.service.current_queue_size,
"max_size": self.service.max_queue_size,
},
)
self.responder.send_response(error_response)
self.responder.ensure_response()
return

# try:
streaming = self.gen_req.stream
# self.responder.start_response()
# if streaming:
# self.responder.stream_start()

# Launch all individual generate processes and wait for them to finish.
gen_processes = []
input_ids = self.gen_req.input_ids
is_pretokenized = input_ids is not None

if is_pretokenized:
input_batch = [input_ids] if self.gen_req.is_single else input_ids
else:
input_batch = self.tokenize()
for index, input_tokens in enumerate(input_batch):
decode_config = copy(self.decode_config)
decode_config.update_from_sampling_params(
self.gen_req.sampling_params
if self.gen_req.is_single
else self.gen_req.sampling_params[index]
)
gen_process = GenerateItemProcess(
self,
self.gen_req,
index,
self.gen_req.text
if self.gen_req.is_single
else self.gen_req.text[index],
input_tokens if is_pretokenized else input_tokens.ids,
eos_token_id=self.tokenizer.eos_token_id,
decode_config=decode_config,
)
gen_processes.append(gen_process)
gen_process.launch()

await asyncio.gather(*gen_processes)
self.generate_response(gen_processes, streaming)
# finally:
# self.responder.ensure_response()

def generate_response(
self, gen_processes: List[GenerateItemProcess], streaming: bool
):
if streaming:
logger.info("Responding to streaming batch")
self.responder.stream_part(b"data: [DONE]\n\n")
self.responder.stream_part(None)
# self.responder.stream_part(b"data: [DONE]\n\n")
# self.responder.stream_part(None)
self.response_handler([b"data: [DONE]\n\n", None])
return

logging.debug("Responding to one shot batch")
Expand All @@ -225,7 +246,8 @@ def generate_response(
result_tokens = result_tokens[0]
out = io.BytesIO()
out.write(bytes(json.dumps(result_tokens), "utf-8"))
self.responder.send_response(out.getvalue())
# self.responder.send_response(out.getvalue())
self.response_handler(out.getvalue())
return

response_map = {}
Expand Down Expand Up @@ -254,7 +276,8 @@ def generate_response(
response = json.dumps(response)
out = io.BytesIO()
out.write(response.encode())
self.responder.send_response(out.getvalue())
# self.responder.send_response(out.getvalue())
self.response_handler(out.getvalue())

def _respond_multi_responses(
self, result_token_ids: List[List[int]], out: io.BytesIO
@@ -292,7 +315,8 @@ def stream_results(self, gen_process: GenerateItemProcess):
out.write(f"data({rid}): ".encode())
out.write(str(result_tokens[0]).encode())
out.write(b"\n\n")
self.responder.stream_part(out.getvalue())
# self.responder.stream_part(out.getvalue())
self.response_handler(out.getvalue())
gen_process.streamed_tokens_index += len(result_tokens)

def tokenize(self) -> list[Encoding]:
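
Note: the core of the generate.py change is swapping the FastAPI-specific responder for a plain response_handler callable, so ClientGenerateBatchProcess no longer holds FastAPI objects directly, which is what allows it to run from a worker process. Below is a minimal sketch of a conforming handler, assuming only the contract stated in the new comment (called with a single bytes/None value, or a list of bytes/None parts for streaming); the multiprocessing-queue wiring is illustrative and not part of this PR:

import multiprocessing as mp
from typing import Callable, List, Optional, Union

ResponsePart = Optional[bytes]  # None marks end-of-stream, per the diff comment


def make_queue_response_handler(
    result_queue: "mp.Queue",
) -> Callable[[Union[ResponsePart, List[ResponsePart]]], None]:
    """Build a response_handler that forwards generated bytes to a queue.

    Hypothetical helper: it follows the callable contract described in the
    diff (a single bytes/None argument, or a list of bytes/None for
    streaming) and relays each part to the parent process via a
    multiprocessing queue.
    """

    def handler(part_or_parts: Union[ResponsePart, List[ResponsePart]]) -> None:
        # Normalize the streaming (list) and one-shot (single value) cases.
        parts = part_or_parts if isinstance(part_or_parts, list) else [part_or_parts]
        for part in parts:
            result_queue.put(part)  # the consumer treats None as end-of-stream

    return handler

A consumer in the web-server process would then drain the queue and feed the original FastAPIResponder (stream_part/send_response); that glue is outside this diff.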