Skip to content

Commit d0e17de

Browse files
committed
Disaggregation works
1 parent c019f20 commit d0e17de

File tree

5 files changed

+31
-17
lines changed

5 files changed

+31
-17
lines changed

shortfin/python/shortfin_apps/llm/cli.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
8282
default=1,
8383
help="Number of workers to use when running in `offline` mode.",
8484
)
85-
parser.add_argument(
86-
"--disaggregate",
87-
action="store_true",
88-
help="Disaggregate the prefill and decode invocations to separate HIP streams.",
89-
)
90-
85+
9186

9287
def parse_args(argv):
9388
parser = argparse.ArgumentParser()

shortfin/python/shortfin_apps/llm/components/batcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
self.scheduler = Scheduler(ideal_batch_size=self.ideal_batch_size)
7070
self.cache = DeviceArrayCache(fiber.device(0))
7171
self.program_isolation = program_isolation
72+
self.exec_fiber = exec_fiber
7273

7374
def handle_inference_request(self, request):
7475
"""Handle an inference request."""
@@ -120,7 +121,7 @@ async def board_flights(self):
120121
scheduled = []
121122
for job in to_schedule:
122123
scheduled = scheduled + job
123-
self.board(cache, self.fiber, job)
124+
self.board(cache, self.exec_fiber, job)
124125
logger.debug("Post boarding cache state: %r", cache)
125126

126127
pending = set(pending) - set(scheduled)

shortfin/python/shortfin_apps/llm/components/generate.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,28 @@ def __init__(
8888
)
8989
self.streamed_tokens_index = 0
9090
self._status_tracker = status_tracker
91-
92-
async def run(self):
93-
exec_req = LlmInferenceExecRequest(
91+
self.exec_req = LlmInferenceExecRequest(
9492
phase=InferencePhase.PREFILL,
9593
input_token_ids=self.input_token_ids,
9694
rid=self.gen_req.rid,
9795
status_tracker=self._status_tracker,
9896
)
99-
exec_req._cache = self.client.prefill_batcher.page_cache
97+
98+
async def run(self):
99+
self.exec_req._cache = self.client.prefill_batcher.page_cache
100100
try:
101101
# Prefill result.
102-
await self.token_selector.prefill(exec_req)
102+
await self.token_selector.prefill(self.exec_req)
103103

104104
# Decode loop.
105-
await self.token_selector.decode(exec_req)
105+
await self.token_selector.decode(self.exec_req)
106106
finally:
107-
exec_req.free_cache_pages()
107+
self.exec_req.completed.set_success()
108+
self.exec_req.free_cache_pages()
109+
110+
async def await_completion(self):
111+
await self.exec_req.completed
112+
return self.index
108113

109114
def results_callback(self, result: int | list[list[int]]):
110115
if is_multi_response(self.decode_config):
@@ -225,6 +230,7 @@ async def run(self):
225230
else:
226231
input_batch = self.tokenize()
227232

233+
pending = []
228234
for index, input_tokens in enumerate(input_batch):
229235
decode_config = copy(self.decode_config)
230236
decode_config.update_from_sampling_params(
@@ -273,11 +279,17 @@ async def run(self):
273279
fiber=fiber,
274280
)
275281
gen_processes.append(gen_process)
282+
pending.append(asyncio.create_task(gen_process.await_completion()))
276283
gen_process.launch()
277284

278-
await asyncio.gather(*gen_processes)
279-
if not self.responder.is_disconnected():
280-
self.generate_response(gen_processes, streaming)
285+
while pending:
286+
done, pending = await asyncio.wait(
287+
pending, return_when=asyncio.FIRST_COMPLETED
288+
)
289+
for task in done:
290+
idx = await task
291+
if not self.responder.is_disconnected():
292+
self.generate_response([gen_processes[idx]], streaming)
281293
finally:
282294
# Remove request from queue when done
283295
self.service.remove_from_queue(self.decode_config.num_beams)

shortfin/python/shortfin_apps/llm/components/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
self.input_token_ids = input_token_ids
3838
self.prompt_length = len(input_token_ids)
3939
self.done = sf.VoidFuture()
40+
self.completed = sf.VoidFuture()
4041
self.rid = rid
4142
# Unique `instance_id` for token selection strategies that may need
4243
# to differentiate between an original req and a copy of a req.

shortfin/python/shortfin_apps/llm/server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def add_service_args(parser: argparse.ArgumentParser):
128128
default=1,
129129
help="Number of fibers to use per worker.",
130130
)
131+
parser.add_argument(
132+
"--disaggregate",
133+
action="store_true",
134+
help="Disaggregate the prefill and decode invocations to separate HIP streams.",
135+
)
131136

132137

133138
def parse_args(argv):

0 commit comments

Comments (0)