Skip to content

Commit 01931fd

Browse files
committed
Assert single physical GPU; Cleanup
1 parent 63f6238 commit 01931fd

File tree

4 files changed

+210
-56
lines changed

4 files changed

+210
-56
lines changed

shortfin/python/shortfin_apps/llm/components/batcher.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
functions: dict[int, sf.ProgramFunction],
5555
ideal_batch_size: int,
5656
program_isolation: str,
57-
exec_fiber: Fiber,
57+
exec_fiber: Fiber | None = None,
5858
):
5959
super().__init__(fiber=fiber)
6060
self.name = name
@@ -121,7 +121,11 @@ async def board_flights(self):
121121
scheduled = []
122122
for job in to_schedule:
123123
scheduled = scheduled + job
124-
self.board(cache, self.exec_fiber, job)
124+
self.board(
125+
cache,
126+
self.exec_fiber if self.exec_fiber is not None else self.fiber,
127+
job,
128+
)
125129
logger.debug("Post boarding cache state: %r", cache)
126130

127131
pending = set(pending) - set(scheduled)
@@ -169,7 +173,7 @@ def __init__(
169173
model_params: ModelParams,
170174
prefill_functions: dict[int, sf.ProgramFunction],
171175
program_isolation: str,
172-
exec_fiber: Fiber,
176+
exec_fiber: Fiber | None = None,
173177
):
174178
super().__init__(
175179
name="prefill",
@@ -227,7 +231,7 @@ def __init__(
227231
model_params: ModelParams,
228232
decode_functions: dict[int, sf.ProgramFunction],
229233
program_isolation: str,
230-
exec_fiber: Fiber,
234+
exec_fiber: Fiber | None = None,
231235
):
232236
super().__init__(
233237
name="decode",

shortfin/python/shortfin_apps/llm/components/fiber_pool.py

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from .manager import LlmSystemManager
1010
import asyncio
1111

12+
from typing import override
13+
1214

1315
class FiberPool:
1416
"""
@@ -33,64 +35,105 @@ def __init__(
3335
self.name: str = name
3436

3537
# Name mangle to make outside access harder.
36-
self.__fiber_pool: list[sf.Fiber] = []
37-
self.__workers: list[sf.Worker] = []
38+
self._fiber_pool: list[sf.Fiber] = []
39+
self._workers: list[sf.Worker] = []
3840
# Keep track of how many extra fibers were created
3941
# during runtime if `resizable` is set to True.
40-
self.__extra_fibers: int = 0
41-
self.__index_queue = asyncio.Queue()
42+
self._extra_fibers: int = 0
43+
self._index_queue = asyncio.Queue()
44+
45+
self._initialize_pool()
4246

43-
self.__initialize_pool()
47+
def resize(self):
48+
new_worker = self.sysman.ls.create_worker(
49+
f"{self.name}-new-worker-{self._extra_fibers}"
50+
)
51+
self._workers.append(new_worker)
52+
fiber = self.sysman.ls.create_fiber(new_worker)
53+
self._fiber_pool.append(fiber)
54+
self._extra_fibers += 1
55+
56+
return [self.size() - 1, fiber]
4457

4558
async def get(self) -> tuple[int, sf.Fiber]:
4659
try:
47-
idx = self.__index_queue.get_nowait()
60+
idx = self._index_queue.get_nowait()
4861
return (
4962
idx,
50-
self.__fiber_pool[idx],
63+
self._fiber_pool[idx],
5164
)
5265
except asyncio.QueueEmpty:
5366
if self.resizable:
5467
# Resize the fiber pool by adding a new fiber.
55-
devices = self.sysman.ls.devices
56-
num_devices = len(devices)
57-
new_worker = self.sysman.ls.create_worker(
58-
f"{self.name}-new-worker-{self.__extra_fibers}"
59-
)
60-
self.__workers.append(new_worker)
61-
62-
fiber = self.sysman.ls.create_fiber(
63-
new_worker, devices=[devices[self.size() % num_devices]]
64-
)
65-
self.__fiber_pool.append(fiber)
66-
self.__extra_fibers += 1
67-
return [self.size() - 1, fiber]
68-
69-
available_index = await self.__index_queue.get()
70-
return (available_index, self.__fiber_pool[available_index])
68+
return self.resize()
69+
70+
available_index = await self._index_queue.get()
71+
return (available_index, self._fiber_pool[available_index])
7172

7273
def pool(self) -> list[sf.Fiber]:
73-
return self.__fiber_pool
74+
return self._fiber_pool
7475

75-
def __initialize_pool(self):
76-
devices = self.sysman.ls.devices
77-
num_devices = len(devices)
76+
def _initialize_pool(self):
7877
for idx in range(self.init_size):
7978
worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
80-
self.__workers.append(worker)
81-
82-
fiber = self.sysman.ls.create_fiber(
83-
worker, devices=[devices[idx % num_devices]]
84-
)
85-
self.__fiber_pool.append(fiber)
79+
self._workers.append(worker)
80+
fiber = self.sysman.ls.create_fiber(worker)
81+
self._fiber_pool.append(fiber)
8682
assert idx < self.size()
87-
self.__index_queue.put_nowait(idx)
83+
self._index_queue.put_nowait(idx)
8884

8985
def return_fiber(self, indices: int | list[int]):
9086
if not isinstance(indices, list):
9187
indices = [indices]
9288
for idx in indices:
93-
self.__index_queue.put_nowait(idx)
89+
self._index_queue.put_nowait(idx)
9490

9591
def size(self) -> int:
96-
return len(self.__fiber_pool)
92+
return len(self._fiber_pool)
93+
94+
95+
class DisaggregatedFiberPool(FiberPool):
96+
def __init__(
97+
self,
98+
sysman: LlmSystemManager,
99+
init_size: int,
100+
resizable: bool = True,
101+
name: str = "default-disagg-fiber-pool",
102+
):
103+
super().__init__(
104+
sysman=sysman,
105+
init_size=init_size,
106+
resizable=resizable,
107+
name=name,
108+
)
109+
110+
@override
111+
def resize(self):
112+
devices = self.sysman.ls.devices
113+
num_devices = len(devices)
114+
new_worker = self.sysman.ls.create_worker(
115+
f"{self.name}-new-worker-{self._extra_fibers}"
116+
)
117+
self._workers.append(new_worker)
118+
119+
fiber = self.sysman.ls.create_fiber(
120+
new_worker, devices=[devices[self.size() % num_devices]]
121+
)
122+
self._fiber_pool.append(fiber)
123+
self._extra_fibers += 1
124+
return [self.size() - 1, fiber]
125+
126+
@override
127+
def _initialize_pool(self):
128+
devices = self.sysman.ls.devices
129+
num_devices = len(devices)
130+
for idx in range(self.init_size):
131+
worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
132+
self._workers.append(worker)
133+
134+
fiber = self.sysman.ls.create_fiber(
135+
worker, devices=[devices[idx % num_devices]]
136+
)
137+
self._fiber_pool.append(fiber)
138+
assert idx < self.size()
139+
self._index_queue.put_nowait(idx)

shortfin/python/shortfin_apps/llm/components/lifecycle.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def lifecycle(app: FastApi):
1717
from .config_struct import ModelParams, ServerParams
1818
from .token_selection_strategy import DecodeConfig
1919
from .manager import LlmSystemManager
20-
from .service import LlmGenerateService
20+
from .service import LlmGenerateService, LlmGenerateDisaggregatedService
2121
from .tokenizer import Tokenizer
2222
from typing import TYPE_CHECKING
2323
from fastapi import FastAPI
@@ -64,10 +64,18 @@ def __init__(self, args):
6464
)
6565
server_params.decode_config = decode_config
6666

67+
service_cls = LlmGenerateService
6768
if args.disaggregate:
69+
# To not run into complications with sharded models, assert that the server is
70+
# being run only on one physical device.
71+
rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")
72+
assert (
73+
rocr_visible_devices is not None and len(rocr_visible_devices) <= 2
74+
), "Running disaggregated prefill on HIP streams is supported only when running on one physical device. Set `ROCR_VISIBLE_DEVICES`=<device_id>."
6875
# Setup two logical devices on one physical device to disaggregate
6976
# prefill and decode invocations to distinct streams.
7077
os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"
78+
service_cls = LlmGenerateDisaggregatedService
7179

7280
# Setup system (configure devices, etc).
7381
sysman = LlmSystemManager(
@@ -84,7 +92,7 @@ def __init__(self, args):
8492
tokenizer = Tokenizer.from_tokenizer_json_file(
8593
args.tokenizer_json, eos_token=eos_token
8694
)
87-
service = LlmGenerateService(
95+
service = service_cls(
8896
name="default",
8997
sysman=sysman,
9098
tokenizer=tokenizer,

0 commit comments

Comments
 (0)