[Shortfin][LLM] Add initial support for disaggregated invocations #1463
base: main
Changes from all commits: 459d31f, 1bd649f, 0919dd8, 897ad06, cce17ce, 29cc3a6
```diff
@@ -17,14 +17,15 @@ def lifecycle(app: FastApi):
 from .config_struct import ModelParams, ServerParams
 from .token_selection_strategy import DecodeConfig
 from .manager import LlmSystemManager
-from .service import LlmGenerateService
+from .service import LlmGenerateService, LlmGenerateDisaggregatedService
 from .tokenizer import Tokenizer
 from typing import TYPE_CHECKING
 from fastapi import FastAPI


 from contextlib import asynccontextmanager
 import logging
+import os


 def get_eos_from_tokenizer_config(json_path):
```
```diff
@@ -63,6 +64,19 @@ def __init__(self, args):
         )
         server_params.decode_config = decode_config

+        service_cls = LlmGenerateService
+        if args.disaggregate:
+            # To not run into complications with sharded models, assert that the server is
+            # being run only on one physical device.
+            rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")
+            assert (
+                rocr_visible_devices is not None and len(rocr_visible_devices) <= 2
+            ), "Running disaggregated prefill on HIP streams is supported only when running on one physical device. Set `ROCR_VISIBLE_DEVICES`=<device_id>."
+            # Setup two logical devices on one physical device to disaggregate
+            # prefill and decode invocations to distinct streams.
+            os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"
+            service_cls = LlmGenerateDisaggregatedService
+
         # Setup system (configure devices, etc).
         sysman = LlmSystemManager(
             device=args.device,
```

Review comment on `rocr_visible_devices = os.environ.get("ROCR_VISIBLE_DEVICES")`:
Devices can be set with […].

Review comment on the assertion:
This still needs to be fixed. I think it'd be better to check […]. I don't set ROCR_VISIBLE_DEVICES when running the server, unless someone has set the system to something like SPX/DPX/etc., and just specify devices with […]. A user could also set ROCR_VISIBLE_DEVICES to multiple devices but run the shortfin server with only one device. We should be reading […], and if […]

Review comment on `os.environ["SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE"] = "2"`:
You should not be using environment variables to plumb the logical devices. Look into the shortfin native library for where you can specify this variable and pass it through.
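On the assertion criticized above: it compares the string length of `ROCR_VISIBLE_DEVICES`, not the number of devices, so a single device identified by a longer id (e.g. a GPU UUID) fails the check while a two-character value such as `0,` passes. Below is a minimal sketch of a count-based check, assuming the intended constraint is "exactly one visible device"; the helper name and message are illustrative only, and it does not address the reviewer's separate point that devices may instead be selected via server flags rather than this environment variable.

```python
import os


def visible_rocr_devices() -> list[str]:
    """Parse ROCR_VISIBLE_DEVICES into a list of device ids (indices or UUIDs)."""
    value = os.environ.get("ROCR_VISIBLE_DEVICES", "")
    return [part.strip() for part in value.split(",") if part.strip()]


# Count devices rather than characters in the raw string.
devices = visible_rocr_devices()
assert len(devices) == 1, (
    "Disaggregated prefill/decode currently requires exactly one physical device. "
    "Set ROCR_VISIBLE_DEVICES=<device_id>."
)
```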
```diff
@@ -78,7 +92,7 @@ def __init__(self, args):
         tokenizer = Tokenizer.from_tokenizer_json_file(
             args.tokenizer_json, eos_token=eos_token
         )
-        service = LlmGenerateService(
+        service = service_cls(
             name="default",
             sysman=sysman,
             tokenizer=tokenizer,
```
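On the earlier review comment about not plumbing the logical-device count through `os.environ`: a purely hypothetical sketch of a configuration-based path is shown below. The `logical_devices_per_physical_device` keyword does not exist on `LlmSystemManager` in this PR; it stands in for whatever option shortfin's native system builder actually exposes, which is what the reviewer suggests locating and passing through.

```python
# Hypothetical only: the keyword below is a placeholder for the real shortfin
# system-builder option that SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE
# maps to; plumbing it here avoids mutating process-wide environment state.
sysman = LlmSystemManager(
    device=args.device,
    # ... existing arguments unchanged ...
    logical_devices_per_physical_device=2 if args.disaggregate else 1,
)
```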
Review comment (on a loop over pending work elsewhere in the PR):
You should be able to construct multiple awaitables, then perform a gather and await. Looping over pending items needlessly makes the Python interpreter manage the scheduling instead of relying on asyncio features.