Skip to content

Use multiple logical devices to handle ClientGenerateBatchProcess #1403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions shortfin/python/shortfin_apps/llm/components/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
self.ideal_batch_size: int = ideal_batch_size
self.page_seq_stride = self.model_params.paged_kv_cache.block_seq_stride
self._current_workitems = 0
self.worker_index = 0

self.fiber_pool = fiber_pool
self.program_isolation = program_isolation
Expand Down Expand Up @@ -217,6 +218,7 @@ def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
cache.page_pool.page_tables,
self.fiber_pool,
self.program_isolation,
self.worker_index,
)

def board_request(self, cache, request: LlmInferenceExecRequest):
Expand Down Expand Up @@ -273,6 +275,7 @@ def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
cache.page_pool.page_tables,
self.fiber_pool,
self.program_isolation,
self.worker_index,
)

def board_request(self, cache, request: LlmInferenceExecRequest):
Expand All @@ -299,6 +302,7 @@ def __init__(
page_tables,
fiber_pool: FiberPool,
program_isolation: sf.ProgramIsolation,
worker_index: int,
):
super().__init__(fiber=fiber)
self.name = name
Expand All @@ -308,18 +312,20 @@ def __init__(
self.functions = functions
self.fiber_pool = fiber_pool
self.program_isolation = program_isolation
self.worker_index = worker_index

async def get_args(self, bs, device0):
async def get_args(self, bs, device_index):
...

async def get_results(self, logits, req_count, device0):
async def get_results(self, logits, req_count, device_index):
...

async def run(self):
try:
req_bs = len(self.exec_requests)
seq_stride = self.seq_stride
device0 = self.fiber.device(0)
current_worker_index = self.worker_index
device0 = self.fiber.device(current_worker_index)
# Select an entrypoint for the batch.
entrypoints = self.functions
for bs, fn in entrypoints.items():
Expand All @@ -328,7 +334,7 @@ async def run(self):
else:
raise RuntimeError(f"No available entry point for bs {req_bs}")

args, req_count = await self.get_args(bs, device0)
args, req_count = await self.get_args(bs, current_worker_index)

logger.info(
"INVOKE %r: %s",
Expand Down Expand Up @@ -366,7 +372,7 @@ async def run(self):
r.publish_allocated_pages(number_of_complete_pages)

# Return results.
await self.get_results(logits, req_count, device0)
await self.get_results(logits, req_count, current_worker_index)

except Exception:
logger.exception("Fatal error in prefetch invocation")
Expand All @@ -388,6 +394,7 @@ def __init__(
page_tables,
fiber_pool: FiberPool,
program_isolation: sf.ProgramIsolation,
worker_index: int,
):
super().__init__(
name="prefill_process",
Expand All @@ -397,9 +404,10 @@ def __init__(
page_tables=page_tables,
fiber_pool=fiber_pool,
program_isolation=program_isolation,
worker_index=worker_index,
)

async def get_args(self, bs, device0):
async def get_args(self, bs, device_index):
seq_stride = self.seq_stride

# Compute block sequence length as maximum sequence length, rounded
Expand All @@ -417,6 +425,7 @@ async def get_args(self, bs, device0):
# TODO: Better support in shortfin for h2d. The best way to do it is
# device dependent.
int_dtype = sfnp.int64
device0 = self.fiber.device(device_index)
tokens = sfnp.device_array.for_device(device0, [bs, bsl], int_dtype)
seq_lens = sfnp.device_array.for_device(device0, [bs], int_dtype)
seq_block_ids = sfnp.device_array.for_device(
Expand Down Expand Up @@ -455,13 +464,14 @@ async def get_args(self, bs, device0):
# seq_block_ids: [bs, blocks]
# cache_slabs: ...
args = [tokens, seq_lens, seq_block_ids]
for page_table in self.page_tables:
args.append(sfnp.disable_barrier(page_table))
page_table = self.page_tables[device_index]
args.append(sfnp.disable_barrier(page_table))

return args, req_count

async def get_results(self, logits, req_count, device0):
# Return results.
async def get_results(self, logits, req_count, device_index):
# Return results
device0 = self.fiber.device(device_index)
for i in range(req_count):
req = self.exec_requests[i]
sl = len(req.input_token_ids)
Expand Down Expand Up @@ -492,6 +502,7 @@ def __init__(
page_tables,
fiber_pool: FiberPool,
isolation: sf.ProgramIsolation,
worker_index: int,
):
super().__init__(
name="decode_process",
Expand All @@ -501,9 +512,10 @@ def __init__(
page_tables=page_tables,
fiber_pool=fiber_pool,
program_isolation=isolation,
worker_index=worker_index,
)

async def get_args(self, bs, device0):
async def get_args(self, bs, device_index):
# Compute block sequence length as maximum sequence length, rounded
# up to the seq_stride.
seq_stride = self.seq_stride
Expand All @@ -517,6 +529,7 @@ async def get_args(self, bs, device0):
# TODO: Better support in shortfin for h2d. The best way to do it is
# device dependent.
int_dtype = sfnp.int64
device0 = self.fiber.device(device_index)
tokens = sfnp.device_array.for_device(device0, [bs, 1], int_dtype)
start_positions = sfnp.device_array.for_device(device0, [bs], int_dtype)
seq_lens = sfnp.device_array.for_device(device0, [bs], int_dtype)
Expand Down Expand Up @@ -576,13 +589,14 @@ async def get_args(self, bs, device0):
# seq_block_ids: [bs, blocks]
# cache_slabs: ...
args = [tokens, seq_lens, start_positions, seq_block_ids]
for page_table in self.page_tables:
args.append(sfnp.disable_barrier(page_table))
page_table = self.page_tables[device_index]
args.append(sfnp.disable_barrier(page_table))

return args, req_count

async def get_results(self, logits, req_count, device0):
async def get_results(self, logits, req_count, device_index):
# Return results.
device0 = self.fiber.device(device_index)
for i in range(req_count):
req = self.exec_requests[i]
sl = 1
Expand Down
11 changes: 11 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _append_token(self, token: int):


class ClientGenerateBatchProcess(sf.Process):
generate_count = 0
"""Process instantiated for handling a batch from a client.

This takes care of several responsibilities:
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
fiber: sf.Fiber | None = None,
):
super().__init__(fiber=service.fiber_pool.fibers[0] if fiber is None else fiber)
ClientGenerateBatchProcess.generate_count += 1
self.service = service
self.gen_req = gen_req
self.responder = responder
Expand All @@ -158,6 +160,12 @@ async def run(self):

# Try to add request to queue
# TODO(@zphoenixrises): Add load testing and integration tests for this.
self.prefill_batcher.worker_index = (
ClientGenerateBatchProcess.generate_count % len(self.fiber.device_names)
)
self.decode_batcher.worker_index = (
ClientGenerateBatchProcess.generate_count % len(self.fiber.device_names)
)
if not self.service.add_to_queue():
error_response = JSONResponse(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
Expand All @@ -173,6 +181,9 @@ async def run(self):
return

try:
logger.debug(
f"add_to_queue, use device {self.prefill_batcher.worker_index} for pre_fill use device {self.decode_batcher.worker_index} for decode"
)
streaming = self.gen_req.stream
self.responder.start_response()
if streaming:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
# Initialize a page table on each device.
page_table_shape = [
self.config.alloc_page_count,
self.config.paged_kv_block_size_elements // len(devices),
self.config.paged_kv_block_size_elements,
]
for device in devices:
logging.info(
Expand Down
6 changes: 6 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def lifecycle(app: FastApi):

from contextlib import asynccontextmanager
import logging
import os


def get_eos_from_tokenizer_config(json_path):
Expand Down Expand Up @@ -63,6 +64,11 @@ def __init__(self, args):
)
server_params.decode_config = decode_config

        # Use the number of workers as the number of logical devices per physical device.
os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = str(
server_params.workers
)

# Setup system (configure devices, etc).
sysman = LlmSystemManager(
device=args.device,
Expand Down
6 changes: 4 additions & 2 deletions shortfin/src/shortfin/local/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,17 @@ struct SHORTFIN_API DeviceAddress {
// Can be used as a map key to uniquely identify this device.
uint64_t device_id() const {
return static_cast<uint64_t>(instance_ordinal) << 32 |
static_cast<uint64_t>(queue_ordinal);
static_cast<uint64_t>(queue_ordinal) << 16 |
static_cast<uint64_t>(instance_topology_address[0]);
}

// Creates a device_id() as if this device was for a different queue ordinal.
// Can be used when reassembling a device id from a queue affinity mask and
// using it to look up in a map.
uint64_t device_id_for_queue(uint32_t alternate_queue_ordinal) const {
return static_cast<uint64_t>(instance_ordinal) << 32 |
static_cast<uint64_t>(alternate_queue_ordinal);
static_cast<uint64_t>(alternate_queue_ordinal) << 16 |
static_cast<uint64_t>(instance_topology_address[0]);
}
};

Expand Down
Loading