[Shortfin][LLM] Add initial support for disaggregated invocations #1463

Status: Open. Wants to merge 6 commits into base: main.
13 changes: 11 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/batcher.py
@@ -54,6 +54,7 @@ def __init__(
functions: dict[int, sf.ProgramFunction],
ideal_batch_size: int,
program_isolation: str,
exec_fiber: Fiber | None = None,
):
super().__init__(fiber=fiber)
self.name = name
@@ -67,8 +68,8 @@ def __init__(
self.page_seq_stride = self.model_params.paged_kv_cache.block_seq_stride
self.scheduler = Scheduler(ideal_batch_size=self.ideal_batch_size)
self.cache = DeviceArrayCache(fiber.device(0))

self.program_isolation = program_isolation
self.exec_fiber = exec_fiber

def handle_inference_request(self, request):
"""Handle an inference request."""
@@ -120,7 +121,11 @@ async def board_flights(self):
scheduled = []
for job in to_schedule:
scheduled = scheduled + job
self.board(cache, self.fiber, job)
self.board(
cache,
self.exec_fiber if self.exec_fiber is not None else self.fiber,
job,
)
logger.debug("Post boarding cache state: %r", cache)

pending = set(pending) - set(scheduled)
@@ -168,6 +173,7 @@ def __init__(
model_params: ModelParams,
prefill_functions: dict[int, sf.ProgramFunction],
program_isolation: str,
exec_fiber: Fiber | None = None,
):
super().__init__(
name="prefill",
@@ -177,6 +183,7 @@
functions=prefill_functions,
ideal_batch_size=max(model_params.prefill_batch_sizes),
program_isolation=program_isolation,
exec_fiber=exec_fiber,
)

def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
@@ -224,6 +231,7 @@ def __init__(
model_params: ModelParams,
decode_functions: dict[int, sf.ProgramFunction],
program_isolation: str,
exec_fiber: Fiber | None = None,
):
super().__init__(
name="decode",
@@ -233,6 +241,7 @@
functions=decode_functions,
ideal_batch_size=max(model_params.decode_batch_sizes),
program_isolation=program_isolation,
exec_fiber=exec_fiber,
)

def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
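For orientation, the behavioral core of the `exec_fiber` change restated as a self-contained sketch: flights are boarded on the optional execution fiber when one is supplied, otherwise on the batcher's own fiber, which is exactly what `board_flights()` now does. The helper below is illustrative only and not part of the PR.

```python
from typing import Optional, TypeVar

F = TypeVar("F")


def boarding_fiber(own_fiber: F, exec_fiber: Optional[F] = None) -> F:
    """Pick the fiber that flights should be boarded on."""
    return exec_fiber if exec_fiber is not None else own_fiber


# boarding_fiber(main_fiber)                -> main_fiber   (previous behavior)
# boarding_fiber(main_fiber, decode_fiber)  -> decode_fiber (disaggregated path)
```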
shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -244,6 +244,7 @@ class ServerParams:
amdgpu_async_caching: bool = False
amdgpu_allocators: Optional[str] = None
amdgpu_allow_device_reuse: bool = False
disaggregate: bool = False

@staticmethod
def load(config_path: Optional[Path] = None) -> "ServerParams":
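A brief, hedged usage sketch of the new `disaggregate` field. The module path is inferred from the relative import `.config_struct` used in lifecycle.py below; how downstream code consumes the flag is not shown in this hunk.

```python
# Minimal sketch: load default server parameters and opt in to
# disaggregated prefill/decode.
from shortfin_apps.llm.components.config_struct import ServerParams

server_params = ServerParams.load()   # load(config_path=None) uses defaults
server_params.disaggregate = True     # new field added in this PR
```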
113 changes: 78 additions & 35 deletions shortfin/python/shortfin_apps/llm/components/fiber_pool.py
@@ -33,64 +33,107 @@ def __init__(
self.sysman: LlmSystemManager = sysman
self.name: str = name

# Name mangle to make outside access harder.
self.__fiber_pool: list[sf.Fiber] = []
self.__workers: list[sf.Worker] = []
self._fiber_pool: list[sf.Fiber] = []
self._workers: list[sf.Worker] = []
# Keep track of how many extra fibers were created
# during runtime if `resizable` is set to True.
self.__extra_fibers: int = 0
self.__index_queue = asyncio.Queue()
# Any code that modifies the index_queue or the fiber_pool
# needs to be locked. asyncio.Queue is not thread-safe, so
# this is required to avoid issues like new fibers with the
# same name as existing ones.
self.__lock = Lock()
self.__initialize_pool()
self._extra_fibers: int = 0
self._index_queue = asyncio.Queue()
self._lock = Lock()
self._initialize_pool()

def resize(self):
new_worker = self.sysman.ls.create_worker(
f"{self.name}-new-worker-{self._extra_fibers}"
)
self._workers.append(new_worker)
fiber = self.sysman.ls.create_fiber(new_worker)
self._fiber_pool.append(fiber)
self._extra_fibers += 1

return [self.size() - 1, fiber]

async def get(self) -> tuple[int, sf.Fiber]:
with self.__lock:
with self._lock:
try:
idx = self.__index_queue.get_nowait()
idx = self._index_queue.get_nowait()
return (
idx,
self.__fiber_pool[idx],
self._fiber_pool[idx],
)
except asyncio.QueueEmpty:
if self.resizable:
# Resize the fiber pool by adding a new fiber.
new_worker = self.sysman.ls.create_worker(
f"{self.name}-new-worker-{self.__extra_fibers}"
)
self.__workers.append(new_worker)

fiber = self.sysman.ls.create_fiber(new_worker)
self.__fiber_pool.append(fiber)
self.__extra_fibers += 1
return [self.size() - 1, fiber]
return self.resize()

available_index = await self.__index_queue.get()
return (available_index, self.__fiber_pool[available_index])
available_index = await self._index_queue.get()
return (available_index, self._fiber_pool[available_index])

def pool(self) -> list[sf.Fiber]:
return self.__fiber_pool
return self._fiber_pool

def __initialize_pool(self):
with self.__lock:
def _initialize_pool(self):
with self._lock:
for idx in range(self.init_size):
worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
self.__workers.append(worker)

self._workers.append(worker)
fiber = self.sysman.ls.create_fiber(worker)
self.__fiber_pool.append(fiber)
self._fiber_pool.append(fiber)
assert idx < self.size()
self.__index_queue.put_nowait(idx)
self._index_queue.put_nowait(idx)

def return_fiber(self, indices: int | list[int]):
with self.__lock:
with self._lock:
if not isinstance(indices, list):
indices = [indices]
for idx in indices:
self.__index_queue.put_nowait(idx)
self._index_queue.put_nowait(idx)

def size(self) -> int:
return len(self.__fiber_pool)
return len(self._fiber_pool)


class DisaggregatedFiberPool(FiberPool):
def __init__(
self,
sysman: LlmSystemManager,
init_size: int,
resizable: bool = True,
name: str = "default-disagg-fiber-pool",
):
super().__init__(
sysman=sysman,
init_size=init_size,
resizable=resizable,
name=name,
)

def resize(self):
devices = self.sysman.ls.devices
num_devices = len(devices)
new_worker = self.sysman.ls.create_worker(
f"{self.name}-new-worker-{self._extra_fibers}"
)
self._workers.append(new_worker)

fiber = self.sysman.ls.create_fiber(
new_worker, devices=[devices[self.size() % num_devices]]
)
self._fiber_pool.append(fiber)
self._extra_fibers += 1
return [self.size() - 1, fiber]

def _initialize_pool(self):
with self._lock:
devices = self.sysman.ls.devices
num_devices = len(devices)
for idx in range(self.init_size):
worker = self.sysman.ls.create_worker(f"{self.name}-init-worker-{idx}")
self._workers.append(worker)

fiber = self.sysman.ls.create_fiber(
worker, devices=[devices[idx % num_devices]]
)
self._fiber_pool.append(fiber)
assert idx < self.size()
self._index_queue.put_nowait(idx)
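A usage sketch for the pool above, not taken from the PR: because `_initialize_pool()` and `resize()` round-robin fibers over `sysman.ls.devices`, two consecutive `get()` calls on a two-logical-device system hand back fibers pinned to different devices (and therefore, with the logical-device split configured in lifecycle.py, to different HIP streams). Only the pool API shown in this diff (`get()`, `return_fiber()`) is assumed.

```python
import asyncio

from shortfin_apps.llm.components.fiber_pool import DisaggregatedFiberPool
from shortfin_apps.llm.components.manager import LlmSystemManager


async def demo(sysman: LlmSystemManager):
    pool = DisaggregatedFiberPool(sysman, init_size=2, resizable=False)

    prefill_idx, prefill_fiber = await pool.get()  # pinned to logical device 0
    decode_idx, decode_fiber = await pool.get()    # pinned to logical device 1

    # ... hand the fibers to the prefill/decode paths (e.g. as exec_fiber) ...

    pool.return_fiber([prefill_idx, decode_idx])   # recycle both indices
```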
32 changes: 22 additions & 10 deletions shortfin/python/shortfin_apps/llm/components/generate.py
@@ -88,23 +88,28 @@ def __init__(
)
self.streamed_tokens_index = 0
self._status_tracker = status_tracker

async def run(self):
exec_req = LlmInferenceExecRequest(
self.exec_req = LlmInferenceExecRequest(
phase=InferencePhase.PREFILL,
input_token_ids=self.input_token_ids,
rid=self.gen_req.rid,
status_tracker=self._status_tracker,
)
exec_req._cache = self.client.prefill_batcher.page_cache

async def run(self):
self.exec_req._cache = self.client.prefill_batcher.page_cache
try:
# Prefill result.
await self.token_selector.prefill(exec_req)
await self.token_selector.prefill(self.exec_req)

# Decode loop.
await self.token_selector.decode(exec_req)
await self.token_selector.decode(self.exec_req)
finally:
exec_req.free_cache_pages()
self.exec_req.request_exec_success.set_success()
self.exec_req.free_cache_pages()

async def await_completion(self):
await self.exec_req.request_exec_success
return self.index

def results_callback(self, result: int | list[list[int]]):
if is_multi_response(self.decode_config):
@@ -225,6 +230,7 @@ async def run(self):
else:
input_batch = self.tokenize()

pending = []
for index, input_tokens in enumerate(input_batch):
decode_config = copy(self.decode_config)
decode_config.update_from_sampling_params(
@@ -273,11 +279,17 @@
fiber=fiber,
)
gen_processes.append(gen_process)
pending.append(asyncio.create_task(gen_process.await_completion()))
gen_process.launch()

await asyncio.gather(*gen_processes)
if not self.responder.is_disconnected():
self.generate_response(gen_processes, streaming)
while pending:
done, pending = await asyncio.wait(
[Review comment, Contributor, on the asyncio.wait call]: You should be able to construct multiple awaitables and then perform a single gather-and-await. Looping on `pending` needlessly has the Python interpreter manage the scheduling instead of relying on asyncio features.
pending, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
idx = await task
if not self.responder.is_disconnected():
self.generate_response([gen_processes[idx]], streaming)
finally:
# Remove request from queue when done
self.service.remove_from_queue(self.decode_config.num_beams)
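For reference, a hedged sketch of the change the reviewer is asking for: collect the completion awaitables and issue a single gather instead of looping on `asyncio.wait`. Note the behavioral difference: with a plain gather, responses are only generated after every process has completed, whereas the loop above responds as each one finishes. Names mirror the diff; the free-function form is for illustration only.

```python
import asyncio


async def respond_after_gather(gen_processes, responder, generate_response, streaming):
    # One gather over all completion coroutines; each await_completion()
    # returns the process's index, as in the diff above.
    indices = await asyncio.gather(*(p.await_completion() for p in gen_processes))
    if not responder.is_disconnected():
        generate_response([gen_processes[i] for i in indices], streaming)
```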
18 changes: 16 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/lifecycle.py
@@ -17,14 +17,15 @@ def lifecycle(app: FastApi):
from .config_struct import ModelParams, ServerParams
from .token_selection_strategy import DecodeConfig
from .manager import LlmSystemManager
from .service import LlmGenerateService
from .service import LlmGenerateService, LlmGenerateDisaggregatedService
from .tokenizer import Tokenizer
from typing import TYPE_CHECKING
from fastapi import FastAPI


from contextlib import asynccontextmanager
import logging
import os


def get_eos_from_tokenizer_config(json_path):
@@ -63,6 +64,19 @@ def __init__(self, args):
)
server_params.decode_config = decode_config

service_cls = LlmGenerateService
if args.disaggregate:
# To not run into complications with sharded models, assert that the server is
# being run only on one physical device.
rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")
[Review comment, Contributor, on the ROCR_VISIBLE_DEVICES lookup]: Devices can be set with --device_ids, without having to set ROCR_VISIBLE_DEVICES.

assert (
[Review comment, Contributor, on the assert]: This still needs to be fixed. I think it'd be better to check device_ids instead of ROCR_VISIBLE_DEVICES. I don't set ROCR_VISIBLE_DEVICES when running the server, unless someone has set the system to something like SPX/DPX/etc.; I just specify devices with --device_ids. A user could also set ROCR_VISIBLE_DEVICES to multiple devices but run the shortfin server with only one device (--device_ids 0). We should read --device_ids to determine the number of devices the user is attempting to select, and only if --device_ids is None should we read ROCR_VISIBLE_DEVICES.

rocr_visible_devices is not None and len(rocr_visible_devices) <= 2
), "Running disaggregated prefill on HIP streams is supported only when running on one physical device. Set `ROCR_VISIBLE_DEVICES`=<device_id>."
# Setup two logical devices on one physical device to disaggregate
# prefill and decode invocations to distinct streams.
os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"
[Review comment, Contributor, on the environment-variable assignment]: You should not be using environment variables to plumb the logical devices. Look into the shortfin native library for where you can specify this variable and pass it through.

service_cls = LlmGenerateDisaggregatedService

# Setup system (configure devices, etc).
sysman = LlmSystemManager(
device=args.device,
@@ -78,7 +92,7 @@
tokenizer = Tokenizer.from_tokenizer_json_file(
args.tokenizer_json, eos_token=eos_token
)
service = LlmGenerateService(
service = service_cls(
name="default",
sysman=sysman,
tokenizer=tokenizer,
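A hedged sketch of the device-count check the reviewers suggest: consult `--device_ids` first and fall back to `ROCR_VISIBLE_DEVICES` only when it is unset. `args.device_ids` is assumed to be a list of device selectors (or None); the exact argparse shape is not shown in this PR.

```python
import os
from typing import Optional


def selected_device_count(args) -> Optional[int]:
    """Number of devices the user selected, or None if unrestricted."""
    device_ids = getattr(args, "device_ids", None)
    if device_ids:                        # prefer the explicit --device_ids selection
        return len(device_ids)
    visible = os.environ.get("ROCR_VISIBLE_DEVICES")
    if visible:                           # comma-separated device list
        return len([d for d in visible.split(",") if d.strip()])
    return None                           # nothing restricted; count unknown here


# Sketch of how the assertion above could then read:
# assert selected_device_count(args) == 1, (
#     "Disaggregated prefill on HIP streams requires exactly one physical device."
# )
```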
3 changes: 3 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/messages.py
@@ -37,6 +37,9 @@ def __init__(
self.input_token_ids = input_token_ids
self.prompt_length = len(input_token_ids)
self.done = sf.VoidFuture()
# This is set to success once the request has been decoded successfully
# and response is ready to be sent back to the client.
self.request_exec_success = sf.VoidFuture()
self.rid = rid
# Unique `instance_id` for token selection strategies that may need
# to differentiate between an original req and a copy of a req.
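To make the intended handshake concrete, a minimal sketch that uses only the `sf.VoidFuture` operations visible in this PR (`set_success()` on the producer side and `await` on the consumer side, as in generate.py). The surrounding class is a stand-in, not the real request type.

```python
import shortfin as sf


class ExampleRequest:
    def __init__(self):
        # Completed once decoding has finished and a response can be returned.
        self.request_exec_success = sf.VoidFuture()

    def finish(self):
        # Producer side (cf. run() in generate.py): signal completion.
        self.request_exec_success.set_success()


async def await_completion(req: ExampleRequest):
    # Consumer side: parks until finish() has been called.
    await req.request_exec_success
```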