Skip to content

Commit 401da69

Browse files
committed
[Shortfin][LLM] Add initial support for disaggregated invocations
1 parent 1c25ae4 commit 401da69

File tree

6 files changed

+71
-16
lines changed

6 files changed

+71
-16
lines changed

shortfin/python/shortfin_apps/llm/cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,18 @@ def add_service_args(parser: argparse.ArgumentParser):
134134
parser.add_argument(
135135
"--benchmark",
136136
action="store_true",
137-
help="Perform a benchmarking run for throughput",
137+
help="Perform a benchmarking run for throughput.",
138138
)
139139
parser.add_argument(
140140
"--benchmark_tasks",
141141
type=int,
142142
default=None,
143-
help="Workload size to benchmark with",
143+
help="Workload size to benchmark with.",
144+
)
145+
parser.add_argument(
146+
"--disaggregate",
147+
action="store_true",
148+
help="Disaggregate the prefill and decode invocations to separate HIP streams.",
144149
)
145150

146151

shortfin/python/shortfin_apps/llm/components/batcher.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
functions: dict[int, sf.ProgramFunction],
5454
ideal_batch_size: int,
5555
program_isolation: str,
56+
exec_fiber: Fiber,
5657
):
5758
super().__init__(fiber=fiber)
5859
self.name = name
@@ -65,7 +66,7 @@ def __init__(
6566
self.ideal_batch_size: int = ideal_batch_size
6667
self.page_seq_stride = self.model_params.paged_kv_cache.block_seq_stride
6768
self.scheduler = Scheduler(ideal_batch_size=self.ideal_batch_size)
68-
69+
self.exec_fiber = exec_fiber
6970
self.program_isolation = program_isolation
7071

7172
def handle_inference_request(self, request):
@@ -161,6 +162,7 @@ def __init__(
161162
model_params: ModelParams,
162163
prefill_functions: dict[int, sf.ProgramFunction],
163164
program_isolation: str,
165+
exec_fiber: Fiber,
164166
):
165167
super().__init__(
166168
name="prefill",
@@ -170,11 +172,12 @@ def __init__(
170172
functions=prefill_functions,
171173
ideal_batch_size=max(model_params.prefill_batch_sizes),
172174
program_isolation=program_isolation,
175+
exec_fiber=exec_fiber,
173176
)
174177

175178
def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
176179
return PrefillExecutorProcess(
177-
fiber,
180+
self.exec_fiber,
178181
self.functions,
179182
self.page_seq_stride,
180183
cache.page_pool.page_tables,
@@ -216,6 +219,7 @@ def __init__(
216219
model_params: ModelParams,
217220
decode_functions: dict[int, sf.ProgramFunction],
218221
program_isolation: str,
222+
exec_fiber: Fiber,
219223
):
220224
super().__init__(
221225
name="decode",
@@ -225,11 +229,12 @@ def __init__(
225229
functions=decode_functions,
226230
ideal_batch_size=max(model_params.decode_batch_sizes),
227231
program_isolation=program_isolation,
232+
exec_fiber=exec_fiber,
228233
)
229234

230235
def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
231236
return DecodeExecutorProcess(
232-
fiber,
237+
self.exec_fiber,
233238
self.functions,
234239
self.page_seq_stride,
235240
cache.page_pool.page_tables,

shortfin/python/shortfin_apps/llm/components/config_struct.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ class ServerParams:
238238
amdgpu_async_caching: bool = False
239239
amdgpu_allocators: Optional[str] = None
240240
amdgpu_allow_device_reuse: bool = False
241+
disaggregate: bool = False
241242

242243
@staticmethod
243244
def load(config_path: Optional[Path] = None) -> "ServerParams":

shortfin/python/shortfin_apps/llm/components/fiber_pool.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,16 @@ async def get(self) -> tuple[int, sf.Fiber]:
5252
except asyncio.QueueEmpty:
5353
if self.resizable:
5454
# Resize the fiber pool by adding a new fiber.
55+
devices = self.sysman.ls.devices
56+
num_devices = len(devices)
5557
new_worker = self.sysman.ls.create_worker(
5658
f"{self.name}-new-worker-{self.__extra_fibers}"
5759
)
5860
self.__workers.append(new_worker)
5961

60-
fiber = self.sysman.ls.create_fiber(new_worker)
62+
fiber = self.sysman.ls.create_fiber(
63+
new_worker, devices=[devices[self.size() % num_devices]]
64+
)
6165
self.__fiber_pool.append(fiber)
6266
self.__extra_fibers += 1
6367
return [self.size() - 1, fiber]
@@ -69,11 +73,15 @@ def pool(self) -> list[sf.Fiber]:
6973
return self.__fiber_pool
7074

7175
def __initialize_pool(self):
76+
devices = self.sysman.ls.devices
77+
num_devices = len(devices)
7278
for idx in range(self.init_size):
7379
worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
7480
self.__workers.append(worker)
7581

76-
fiber = self.sysman.ls.create_fiber(worker)
82+
fiber = self.sysman.ls.create_fiber(
83+
worker, devices=[devices[idx % num_devices]]
84+
)
7785
self.__fiber_pool.append(fiber)
7886
assert idx < self.size()
7987
self.__index_queue.put_nowait(idx)

shortfin/python/shortfin_apps/llm/components/lifecycle.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def lifecycle(app: FastApi):
2525

2626
from contextlib import asynccontextmanager
2727
import logging
28+
import os
2829

2930

3031
def get_eos_from_tokenizer_config(json_path):
@@ -63,6 +64,11 @@ def __init__(self, args):
6364
)
6465
server_params.decode_config = decode_config
6566

67+
if args.disaggregate:
68+
# Setup two logical devices on one physical device to disaggregate
69+
# prefill and decode invocations to distinct streams.
70+
os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"
71+
6672
# Setup system (configure devices, etc).
6773
sysman = LlmSystemManager(
6874
device=args.device,

shortfin/python/shortfin_apps/llm/components/service.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class LlmGenerateService(GenerateService):
3434
"""Top level service interface for generating text against a model."""
3535

36-
inference_program: sf.Program
36+
inference_program: list[sf.Program]
3737
prefill_functions: dict[int, sf.ProgramFunction]
3838
decode_functions: dict[int, sf.ProgramFunction]
3939

@@ -53,6 +53,7 @@ def __init__(
5353
self.tokenizer = tokenizer
5454
self.model_params = model_params
5555
self.server_params = server_params
56+
self.disaggregate = server_params.disaggregate
5657
self.max_queue_size = max_queue_size
5758
self.current_queue_size = 0
5859
self.main_fiber_pool = FiberPool(
@@ -92,23 +93,30 @@ def remove_from_queue(self, num_beams: int):
9293
def _initialize_worker_and_fiber(self):
9394
num_workers = self.server_params.workers
9495
fibers_per_worker = self.server_params.fibers_per_worker
96+
devices = self.sysman.ls.devices
9597

9698
logger.info(
9799
f"Creating {num_workers} workers, with {fibers_per_worker} fibers per worker..."
98100
)
99101

100102
self.main_worker = self.sysman.ls.create_worker(f"{self.name}-inference-main-0")
101-
self.main_fiber = self.sysman.ls.create_fiber(self.main_worker)
103+
self.main_fiber = self.sysman.ls.create_fiber(
104+
self.main_worker, devices=[devices[0]]
105+
)
102106

103107
self.prefill_worker = self.sysman.ls.create_worker(
104108
f"{self.name}-inference-prefill-0"
105109
)
106-
self.prefill_fiber = self.sysman.ls.create_fiber(self.prefill_worker)
110+
self.prefill_fiber = self.sysman.ls.create_fiber(
111+
self.prefill_worker, devices=[devices[0]]
112+
)
107113

108114
self.decode_worker = self.sysman.ls.create_worker(
109115
f"{self.name}-inference-decode-0"
110116
)
111-
self.decode_fiber = self.sysman.ls.create_fiber(self.decode_worker)
117+
self.decode_fiber = self.sysman.ls.create_fiber(
118+
self.decode_worker, devices=[devices[1 % len(devices)]]
119+
)
112120

113121
self.devices = self.prefill_fiber.devices_dict.values()
114122

@@ -141,17 +149,36 @@ def _initialize_page_cache(self):
141149

142150
def start(self):
143151
component_modules = self.initialize_program_modules("main")
144-
self.inference_program = self.create_program(
145-
modules=component_modules, devices=self.sysman.ls.devices
146-
)
152+
print(f"{self.disaggregate=}")
153+
self.inference_program = [
154+
self.create_program(
155+
modules=component_modules, devices=[self.sysman.ls.devices[idx]]
156+
)
157+
for idx in range(len(self.sysman.ls.devices))
158+
]
147159
self.initialize_function_references()
148160

161+
task_list = [
162+
"prefill-exec",
163+
"decode-exec",
164+
]
165+
166+
devices = self.sysman.ls.devices
167+
workers = [self.sysman.ls.create_worker(f"{task}-worker") for task in task_list]
168+
fibers = [
169+
self.sysman.ls.create_fiber(
170+
workers[idx], devices=[devices[idx % len(devices)]]
171+
)
172+
for idx in range(len(workers))
173+
]
174+
149175
self.prefill_batcher = PrefillBatcherProcess(
150176
self.prefill_fiber,
151177
self.page_cache,
152178
self.model_params,
153179
self.prefill_functions,
154180
self.prog_isolation,
181+
fibers[0],
155182
)
156183

157184
self.decode_batcher = DecodeBatcherProcess(
@@ -160,21 +187,24 @@ def start(self):
160187
self.model_params,
161188
self.decode_functions,
162189
self.prog_isolation,
190+
fibers[1],
163191
)
164192

165193
self.prefill_batcher.launch()
166194
self.decode_batcher.launch()
167195

168196
def initialize_function_references(self):
197+
devices = self.sysman.ls.devices
198+
num_devices = len(devices)
169199
self.prefill_functions = {}
170200
for bs in self.model_params.prefill_batch_sizes:
171-
self.prefill_functions[bs] = self.inference_program[
201+
self.prefill_functions[bs] = self.inference_program[0][
172202
f"{self.model_params.module_name}.prefill_bs{bs}"
173203
]
174204
# Resolve decode entrypoints.
175205
self.decode_functions = {}
176206
for bs in self.model_params.decode_batch_sizes:
177-
self.decode_functions[bs] = self.inference_program[
207+
self.decode_functions[bs] = self.inference_program[1 % num_devices][
178208
f"{self.model_params.module_name}.decode_bs{bs}"
179209
]
180210

0 commit comments

Comments
 (0)