
Commit 9214f22

Assert single physical GPU; Cleanup
1 parent 63f6238 commit 9214f22

File tree

shortfin/python/shortfin_apps/llm/components/batcher.py
shortfin/python/shortfin_apps/llm/components/fiber_pool.py
shortfin/python/shortfin_apps/llm/components/lifecycle.py
shortfin/python/shortfin_apps/llm/components/service.py

4 files changed: +202 -56 lines

shortfin/python/shortfin_apps/llm/components/batcher.py

Lines changed: 8 additions & 4 deletions
@@ -54,7 +54,7 @@ def __init__(
         functions: dict[int, sf.ProgramFunction],
         ideal_batch_size: int,
         program_isolation: str,
-        exec_fiber: Fiber,
+        exec_fiber: Fiber | None = None,
     ):
         super().__init__(fiber=fiber)
         self.name = name
@@ -121,7 +121,11 @@ async def board_flights(self):
         scheduled = []
         for job in to_schedule:
             scheduled = scheduled + job
-            self.board(cache, self.exec_fiber, job)
+            self.board(
+                cache,
+                self.exec_fiber if self.exec_fiber is not None else self.fiber,
+                job,
+            )
         logger.debug("Post boarding cache state: %r", cache)
 
         pending = set(pending) - set(scheduled)
@@ -169,7 +173,7 @@ def __init__(
         model_params: ModelParams,
         prefill_functions: dict[int, sf.ProgramFunction],
         program_isolation: str,
-        exec_fiber: Fiber,
+        exec_fiber: Fiber | None = None,
     ):
         super().__init__(
             name="prefill",
@@ -227,7 +231,7 @@ def __init__(
         model_params: ModelParams,
         decode_functions: dict[int, sf.ProgramFunction],
         program_isolation: str,
-        exec_fiber: Fiber,
+        exec_fiber: Fiber | None = None,
     ):
         super().__init__(
             name="decode",

shortfin/python/shortfin_apps/llm/components/fiber_pool.py

Lines changed: 75 additions & 37 deletions
@@ -32,65 +32,103 @@ def __init__(
         self.sysman: LlmSystemManager = sysman
         self.name: str = name
 
-        # Name mangle to make outside access harder.
-        self.__fiber_pool: list[sf.Fiber] = []
-        self.__workers: list[sf.Worker] = []
+        self._fiber_pool: list[sf.Fiber] = []
+        self._workers: list[sf.Worker] = []
         # Keep track of how many extra fibers were created
         # during runtime if `resizable` is set to True.
-        self.__extra_fibers: int = 0
-        self.__index_queue = asyncio.Queue()
+        self._extra_fibers: int = 0
+        self._index_queue = asyncio.Queue()
 
-        self.__initialize_pool()
+        self._initialize_pool()
+
+    def resize(self):
+        new_worker = self.sysman.ls.create_worker(
+            f"{self.name}-new-worker-{self._extra_fibers}"
+        )
+        self._workers.append(new_worker)
+        fiber = self.sysman.ls.create_fiber(new_worker)
+        self._fiber_pool.append(fiber)
+        self._extra_fibers += 1
+
+        return [self.size() - 1, fiber]
 
     async def get(self) -> tuple[int, sf.Fiber]:
         try:
-            idx = self.__index_queue.get_nowait()
+            idx = self._index_queue.get_nowait()
             return (
                 idx,
-                self.__fiber_pool[idx],
+                self._fiber_pool[idx],
             )
         except asyncio.QueueEmpty:
             if self.resizable:
                 # Resize the fiber pool by adding a new fiber.
-                devices = self.sysman.ls.devices
-                num_devices = len(devices)
-                new_worker = self.sysman.ls.create_worker(
-                    f"{self.name}-new-worker-{self.__extra_fibers}"
-                )
-                self.__workers.append(new_worker)
-
-                fiber = self.sysman.ls.create_fiber(
-                    new_worker, devices=[devices[self.size() % num_devices]]
-                )
-                self.__fiber_pool.append(fiber)
-                self.__extra_fibers += 1
-                return [self.size() - 1, fiber]
-
-            available_index = await self.__index_queue.get()
-            return (available_index, self.__fiber_pool[available_index])
+                return self.resize()
+
+            available_index = await self._index_queue.get()
+            return (available_index, self._fiber_pool[available_index])
 
     def pool(self) -> list[sf.Fiber]:
-        return self.__fiber_pool
+        return self._fiber_pool
 
-    def __initialize_pool(self):
-        devices = self.sysman.ls.devices
-        num_devices = len(devices)
+    def _initialize_pool(self):
         for idx in range(self.init_size):
             worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
-            self.__workers.append(worker)
-
-            fiber = self.sysman.ls.create_fiber(
-                worker, devices=[devices[idx % num_devices]]
-            )
-            self.__fiber_pool.append(fiber)
+            self._workers.append(worker)
+            fiber = self.sysman.ls.create_fiber(worker)
+            self._fiber_pool.append(fiber)
             assert idx < self.size()
-            self.__index_queue.put_nowait(idx)
+            self._index_queue.put_nowait(idx)
 
     def return_fiber(self, indices: int | list[int]):
         if not isinstance(indices, list):
             indices = [indices]
         for idx in indices:
-            self.__index_queue.put_nowait(idx)
+            self._index_queue.put_nowait(idx)
 
     def size(self) -> int:
-        return len(self.__fiber_pool)
+        return len(self._fiber_pool)
+
+
+class DisaggregatedFiberPool(FiberPool):
+    def __init__(
+        self,
+        sysman: LlmSystemManager,
+        init_size: int,
+        resizable: bool = True,
+        name: str = "default-disagg-fiber-pool",
+    ):
+        super().__init__(
+            sysman=sysman,
+            init_size=init_size,
+            resizable=resizable,
+            name=name,
+        )
+
+    def resize(self):
+        devices = self.sysman.ls.devices
+        num_devices = len(devices)
+        new_worker = self.sysman.ls.create_worker(
+            f"{self.name}-new-worker-{self._extra_fibers}"
+        )
+        self._workers.append(new_worker)
+
+        fiber = self.sysman.ls.create_fiber(
+            new_worker, devices=[devices[self.size() % num_devices]]
+        )
+        self._fiber_pool.append(fiber)
+        self._extra_fibers += 1
+        return [self.size() - 1, fiber]
+
+    def _initialize_pool(self):
+        devices = self.sysman.ls.devices
+        num_devices = len(devices)
+        for idx in range(self.init_size):
+            worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
+            self._workers.append(worker)
+
+            fiber = self.sysman.ls.create_fiber(
+                worker, devices=[devices[idx % num_devices]]
+            )
+            self._fiber_pool.append(fiber)
+            assert idx < self.size()
+            self._index_queue.put_nowait(idx)
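
With this split, the base FiberPool leaves device placement to the system (fibers are created without an explicit device list), while DisaggregatedFiberPool keeps the old behavior of round-robining fibers across the logical devices. Callers check fibers out with get() and hand them back with return_fiber(); a small usage sketch, assuming an already-started LlmSystemManager named sysman (the call sites shown here are illustrative, not taken from this commit):

from shortfin_apps.llm.components.fiber_pool import FiberPool


async def run_on_pooled_fiber(sysman, work):
    # Pool of three fibers; resizable=True lets get() grow the pool instead of blocking.
    pool = FiberPool(sysman, 3, resizable=True)

    # get() yields an (index, fiber) pair; the index is the ticket used to return it.
    idx, fiber = await pool.get()
    try:
        await work(fiber)
    finally:
        # Hand the fiber back so the next caller can reuse it.
        pool.return_fiber(idx)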

shortfin/python/shortfin_apps/llm/components/lifecycle.py

Lines changed: 10 additions & 2 deletions
@@ -17,7 +17,7 @@ def lifecycle(app: FastApi):
 from .config_struct import ModelParams, ServerParams
 from .token_selection_strategy import DecodeConfig
 from .manager import LlmSystemManager
-from .service import LlmGenerateService
+from .service import LlmGenerateService, LlmGenerateDisaggregatedService
 from .tokenizer import Tokenizer
 from typing import TYPE_CHECKING
 from fastapi import FastAPI
@@ -64,10 +64,18 @@ def __init__(self, args):
         )
         server_params.decode_config = decode_config
 
+        service_cls = LlmGenerateService
         if args.disaggregate:
+            # To not run into complications with sharded models, assert that the server is
+            # being run only on one physical device.
+            rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")
+            assert (
+                rocr_visible_devices is not None and len(rocr_visible_devices) <= 2
+            ), "Running disaggregated prefill on HIP streams is supported only when running on one physical device. Set `ROCR_VISIBLE_DEVICES`=<device_id>."
            # Setup two logical devices on one physical device to disaggregate
            # prefill and decode invocations to distinct streams.
            os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"
+            service_cls = LlmGenerateDisaggregatedService
 
         # Setup system (configure devices, etc).
         sysman = LlmSystemManager(
@@ -84,7 +92,7 @@ def __init__(self, args):
         tokenizer = Tokenizer.from_tokenizer_json_file(
             args.tokenizer_json, eos_token=eos_token
         )
-        service = LlmGenerateService(
+        service = service_cls(
             name="default",
             sysman=sysman,
             tokenizer=tokenizer,
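
The new assertion gates disaggregation on ROCR_VISIBLE_DEVICES naming a single physical device. Note that it tests the string length of the variable (at most two characters), so a single device id such as "0" or "12" passes while a comma-separated list such as "0,1" fails. A standalone sketch of the same check, using only the standard library:

import os


def assert_single_physical_gpu() -> None:
    # Mirrors the lifecycle.py check: the *string length* of ROCR_VISIBLE_DEVICES must be
    # at most 2, which admits a single device id ("0", "12") and rejects lists ("0,1").
    rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")
    assert (
        rocr_visible_devices is not None and len(rocr_visible_devices) <= 2
    ), "Set ROCR_VISIBLE_DEVICES=<device_id> to a single device before enabling disaggregation."


os.environ["ROCR_VISIBLE_DEVICES"] = "0"
assert_single_physical_gpu()  # passes

os.environ["ROCR_VISIBLE_DEVICES"] = "0,1"
try:
    assert_single_physical_gpu()  # raises AssertionError
except AssertionError as err:
    print(err)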

shortfin/python/shortfin_apps/llm/components/service.py

Lines changed: 109 additions & 13 deletions
@@ -25,7 +25,7 @@
 from .token_selection_strategy import is_multi_response
 
 from ...utils import GenerateService
-from .fiber_pool import FiberPool
+from .fiber_pool import FiberPool, DisaggregatedFiberPool
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +56,8 @@ def __init__(
         self.disaggregate = server_params.disaggregate
         self.max_queue_size = max_queue_size
         self.current_queue_size = 0
-        self.main_fiber_pool = FiberPool(
+        fiber_pool_cls = DisaggregatedFiberPool if self.disaggregate else FiberPool
+        self.main_fiber_pool = fiber_pool_cls(
             self.sysman, self.max_queue_size, resizable=True
         )
 
@@ -93,30 +94,23 @@ def remove_from_queue(self, num_beams: int):
     def _initialize_worker_and_fiber(self):
         num_workers = self.server_params.workers
         fibers_per_worker = self.server_params.fibers_per_worker
-        devices = self.sysman.ls.devices
 
         logger.info(
             f"Creating {num_workers} workers, with {fibers_per_worker} fibers per worker..."
         )
 
         self.main_worker = self.sysman.ls.create_worker(f"{self.name}-inference-main-0")
-        self.main_fiber = self.sysman.ls.create_fiber(
-            self.main_worker, devices=[devices[0]]
-        )
+        self.main_fiber = self.sysman.ls.create_fiber(self.main_worker)
 
         self.prefill_worker = self.sysman.ls.create_worker(
             f"{self.name}-inference-prefill-0"
         )
-        self.prefill_fiber = self.sysman.ls.create_fiber(
-            self.prefill_worker, devices=[devices[0]]
-        )
+        self.prefill_fiber = self.sysman.ls.create_fiber(self.prefill_worker)
 
         self.decode_worker = self.sysman.ls.create_worker(
             f"{self.name}-inference-decode-0"
         )
-        self.decode_fiber = self.sysman.ls.create_fiber(
-            self.decode_worker, devices=[devices[1 % len(devices)]]
-        )
+        self.decode_fiber = self.sysman.ls.create_fiber(self.decode_worker)
 
         self.devices = self.prefill_fiber.devices_dict.values()
 
@@ -147,6 +141,108 @@ def _initialize_page_cache(self):
                 f"Unknown prefix_sharing_algorithm {self.server_params.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'."
             )
 
+    def start(self):
+        component_modules = self.initialize_program_modules("main")
+        self.inference_program = self.create_program(
+            modules=component_modules, devices=self.sysman.ls.devices
+        )
+        self.initialize_function_references()
+
+        self.prefill_batcher = PrefillBatcherProcess(
+            self.prefill_fiber,
+            self.page_cache,
+            self.model_params,
+            self.prefill_functions,
+            self.prog_isolation,
+        )
+
+        self.decode_batcher = DecodeBatcherProcess(
+            self.decode_fiber,
+            self.page_cache,
+            self.model_params,
+            self.decode_functions,
+            self.prog_isolation,
+        )
+
+        self.prefill_batcher.launch()
+        self.decode_batcher.launch()
+
+    def initialize_function_references(self):
+        self.prefill_functions = {}
+        for bs in self.model_params.prefill_batch_sizes:
+            self.prefill_functions[bs] = self.inference_program[
+                f"{self.model_params.module_name}.prefill_bs{bs}"
+            ]
+        # Resolve decode entrypoints.
+        self.decode_functions = {}
+        for bs in self.model_params.decode_batch_sizes:
+            self.decode_functions[bs] = self.inference_program[
+                f"{self.model_params.module_name}.decode_bs{bs}"
+            ]
+
+    def __repr__(self):
+        return (
+            f"ServiceManager(\n"
+            f" model_params={self.model_params}\n"
+            f" server_params={self.server_params}\n"
+            f" inference_modules={self.inference_modules}\n"
+            f" page_cache={self.page_cache}\n"
+            f")"
+        )
+
+
+class LlmGenerateDisaggregatedService(LlmGenerateService):
+    def __init__(
+        self,
+        *,
+        name: str,
+        sysman: LlmSystemManager,
+        tokenizer: Tokenizer,
+        model_params: ModelParams,
+        server_params: "ServerParams",
+        program_isolation: str = "per_call",
+        max_queue_size: int = 3,  # Maximum number of requests in queue
+    ):
+        super().__init__(
+            name=name,
+            sysman=sysman,
+            tokenizer=tokenizer,
+            model_params=model_params,
+            server_params=server_params,
+            program_isolation=program_isolation,
+            max_queue_size=max_queue_size,
+        )
+
+    def _initialize_worker_and_fiber(self):
+        num_workers = self.server_params.workers
+        fibers_per_worker = self.server_params.fibers_per_worker
+        devices = self.sysman.ls.devices
+
+        logger.info(
+            f"Creating {num_workers} workers, with {fibers_per_worker} fibers per worker..."
+        )
+
+        self.main_worker = self.sysman.ls.create_worker(f"{self.name}-inference-main-0")
+        self.main_fiber = self.sysman.ls.create_fiber(
+            self.main_worker, devices=[devices[0]]
+        )
+
+        self.prefill_worker = self.sysman.ls.create_worker(
+            f"{self.name}-inference-prefill-0"
+        )
+        self.prefill_fiber = self.sysman.ls.create_fiber(
+            self.prefill_worker, devices=[devices[0]]
+        )
+
+        self.decode_worker = self.sysman.ls.create_worker(
+            f"{self.name}-inference-decode-0"
+        )
+        self.decode_fiber = self.sysman.ls.create_fiber(
+            self.decode_worker, devices=[devices[1 % len(devices)]]
+        )
+
+        self.devices = self.prefill_fiber.devices_dict.values()
+
     def start(self):
         component_modules = self.initialize_program_modules("main")
         print(f"{self.disaggregate=}")
@@ -210,7 +306,7 @@ def initialize_function_references(self):
 
     def __repr__(self):
         return (
-            f"ServiceManager(\n"
+            f"DisaggregatedServiceManager(\n"
             f" model_params={self.model_params}\n"
             f" server_params={self.server_params}\n"
             f" inference_modules={self.inference_modules}\n"

0 commit comments
