[Shortfin][LLM] Add initial support for disaggregated invocations #1463


Open

vinayakdsci wants to merge 4 commits into main from disaggregated-invocation

Conversation

vinayakdsci (Contributor)

Implements initial support for running prefill and decode invocations in the LLM server on independent HIP streams by creating two HAL devices on the same physical device.

The behavior is enabled by passing --disaggregate to the server script invocation.
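
A conceptual sketch of the idea described above (not shortfin's actual API; the class and helper below are hypothetical stand-ins): one physical GPU is exposed as two logical HAL devices, so prefill and decode each get an independent HIP stream.

# Hypothetical stand-ins to illustrate the PR description; not shortfin API.
from dataclasses import dataclass


@dataclass
class LogicalDevice:
    physical_id: int  # the single physical GPU
    stream: int       # independent HIP stream on that GPU


def split_for_disaggregation(physical_id: int) -> tuple[LogicalDevice, LogicalDevice]:
    """Model one physical device as two logical HAL devices."""
    prefill_device = LogicalDevice(physical_id, stream=0)  # serves prefill invocations
    decode_device = LogicalDevice(physical_id, stream=1)   # serves decode invocations
    return prefill_device, decode_device


print(split_for_disaggregation(physical_id=0))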

@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch from 401da69 to 3053ae0 Compare May 16, 2025 15:46
@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch from 3053ae0 to 88f2855 Compare June 4, 2025 14:44
@vinayakdsci vinayakdsci marked this pull request as ready for review June 4, 2025 14:45
@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch from 88f2855 to d0e17de Compare June 4, 2025 14:51
@stbaione (Contributor) left a comment

It doesn't look like this will behave as expected with the current sharding code, mostly because we're missing an abstraction layer and making assumptions based on the number of devices.

self.inference_program = self.create_program(
    modules=component_modules, devices=self.sysman.ls.devices
)
print(f"{self.disaggregate=}")
stbaione (Contributor)

This looks like it's going to clash with sharded models. Unfortunately, we have a gap in the server CI for that case :(

In general, this PR looks like it ties the --disaggregate topology to the number of devices specified, which concerns me given that we now have more going on topology-wise.

Let's say I'm running a tp8 sharded model. With this setup, we would create 8 different programs, one per device, whereas previously we created one program for all 8 devices.

It gives me some concern that there may be a few spots where assumptions based on the number of devices could cause unintended or undefined behavior for the sharded case.

IIUC, a sharded model would still benefit from a disaggregated approach.

We already make an assumption about the topology based on the number of devices here, and we've regretted it since.

All that being said, we should eventually fix that spot in PagePool and refrain from making model or server assumptions based on the number of devices.

Is this maybe looking ahead to running an unsharded model disaggregated across multiple devices? If so, at some point (probably soon) we should introduce some kind of ModelTopology abstraction that specifies how each device should be used: whether it's sharded or disaggregated, which devices to use for prefill, which to use for decode, etc. But it would be messy to do that without an abstraction layer.

If that's the reason for this change, can we simplify the PR to only support two HIP streams on a single device, and extend to multi-device later? That seems like it would play more easily with our current sharding code. I may also just be reading something wrong.
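
To make the concern concrete, here is an illustrative sketch (plain Python with hypothetical names, not the actual server code) contrasting the two program-creation shapes being discussed:

# Hypothetical sketch contrasting the two shapes described above.
def create_program(modules, devices):
    # Stand-in for the real create_program; just records which devices it spans.
    return f"program(devices={devices})"

modules = ["llm_module"]
devices = [f"device{i}" for i in range(8)]  # a tp8 sharded model

# Previously: one program spanning all eight devices.
single_program = create_program(modules, devices=devices)

# As written in this PR: topology is inferred from the device count,
# yielding one program per device, i.e. eight programs for tp8.
per_device_programs = [create_program(modules, devices=[d]) for d in devices]

print(single_program)
print(len(per_device_programs), "programs")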

workers = [self.sysman.ls.create_worker(f"{task}-worker") for task in task_list]
fibers = [
    self.sysman.ls.create_fiber(
        workers[idx], devices=[devices[idx % len(devices)]]
stbaione (Contributor)

The idx % len(devices) part here confuses me.

For the unsharded case with one device, we only create two workers, so this should only ever create two fibers for device0. If so, it'd be more explicit to just say devices[0].

For the sharded case with tp devices, this would create a fiber for device0 and one for device1, even if no --disaggregate option was passed to the server. Is that intended?
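
A minimal illustration (plain Python, no shortfin calls) of what the indexing above does in the two cases:

task_list = ["prefill", "decode"]

# Unsharded, single device: both fibers land on device0, so devices[0] would be clearer.
devices = ["device0"]
print([devices[idx % len(devices)] for idx in range(len(task_list))])
# -> ['device0', 'device0']

# Sharded with two devices: one fiber per device, even without --disaggregate.
devices = ["device0", "device1"]
print([devices[idx % len(devices)] for idx in range(len(task_list))])
# -> ['device0', 'device1']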

f"{self.model_params.module_name}.prefill_bs{bs}"
]
# Resolve decode entrypoints.
self.decode_functions = {}
for bs in self.model_params.decode_batch_sizes:
self.decode_functions[bs] = self.inference_program[
self.decode_functions[bs] = self.inference_program[1 % num_devices][
stbaione (Contributor)

Same question here
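
The same modular-indexing pattern applies to the decode entrypoints; a quick illustration of what 1 % num_devices selects:

for num_devices in (1, 2, 8):
    print(num_devices, "->", 1 % num_devices)
# 1 -> 0  (decode resolves from the same program as prefill)
# 2 -> 1  (decode resolves from the second program)
# 8 -> 1  (still the second program, no matter how many devices exist)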

@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch from 862d8fc to 01931fd Compare June 5, 2025 05:31
@vinayakdsci vinayakdsci requested a review from stbaione June 5, 2025 05:33
@vinayakdsci (Contributor, Author)

@stbaione Great point about the sharded case -- I had it in the back of my mind that I needed to handle that, but must have forgotten to.

I have separated all the code into new classes and added an assertion to check that the server is running on a physical GPU when executing in disaggregated mode. The solution is not very elegant, but it is the best I could come up with that early in the pipeline.
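
A minimal sketch of the kind of guard described (the function and parameter names here are hypothetical; only the --disaggregate behavior comes from this PR):

# Hypothetical guard, not the actual server code.
def validate_disaggregation(disaggregate: bool, device: str) -> None:
    """Reject disaggregated mode unless the server targets a physical (HIP) GPU."""
    if disaggregate and not device.startswith("hip"):
        raise ValueError("--disaggregate requires a physical GPU (HIP) device")

validate_disaggregation(disaggregate=True, device="hip")  # ok
# validate_disaggregation(disaggregate=True, device="local-task")  # raises ValueError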

Sharding should be okay now.

@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch 2 times, most recently from ad3b24a to 9214f22 Compare June 5, 2025 05:43
@@ -37,6 +37,7 @@ def __init__(
         self.input_token_ids = input_token_ids
         self.prompt_length = len(input_token_ids)
         self.done = sf.VoidFuture()
+        self.completed = sf.VoidFuture()
AWoloszyn (Contributor)

This either needs more comments or better names.
What is the difference between done and completed?

vinayakdsci (Contributor, Author)

@AWoloszyn addressed in the latest commit.

@vinayakdsci vinayakdsci force-pushed the disaggregated-invocation branch from 9214f22 to e8b4652 Compare June 5, 2025 15:10
@@ -274,7 +274,7 @@ def start(self):
             self.model_params,
             self.prefill_functions,
             self.prog_isolation,
-            fibers[0],
+            exec_fibers[0],
Contributor

I recommend using named constants for the 0 and 1 rather than magic numbers (here and in inference_program below).
It should make this easier to reason about.
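
A hedged sketch of that suggestion (the constant names are hypothetical, not what the PR ultimately adopted):

# Hypothetical named constants replacing the magic 0/1 indices.
PREFILL_INDEX = 0  # fiber / program slot used for prefill
DECODE_INDEX = 1   # fiber / program slot used for decode

exec_fibers = ["prefill-fiber", "decode-fiber"]  # stand-ins for real fiber objects
prefill_fiber = exec_fibers[PREFILL_INDEX]
decode_fiber = exec_fibers[DECODE_INDEX]
print(prefill_fiber, decode_fiber)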
