[Shortfin][LLM] Add initial support for disaggregated invocations #1463
Force-pushed from 401da69 to 3053ae0
Force-pushed from 3053ae0 to 88f2855
Force-pushed from 88f2855 to d0e17de
It doesn't look like this will behave as expected with the current sharding code, mostly because we're missing an abstraction layer and making assumptions based on the number of devices.
```python
self.inference_program = self.create_program(
    modules=component_modules, devices=self.sysman.ls.devices
)
print(f"{self.disaggregate=}")
```
This looks like it's going to clash with sharded models. Seeing that we have a gap in the server CI for that, unfortunately :(

In general in this PR, it looks like you're tying the topology of `--disaggregate` to the number of devices specified, but that gives me concern considering we now have more going on topology-wise.

Let's say I'm running a tp8-sharded model. With this setup, we would create eight different programs, one per device, where previously we would create one program for all eight devices. That gives me some concern that there may be a few spots where assumptions based on the number of devices could cause unintended or undefined behavior for the sharded case. IIUC, sharded models would still benefit from a disaggregated approach.

We already make an assumption about the topology based on the number of devices here, which we've regretted since. All this being said, we should eventually fix that spot in `PagePool` and refrain from making model or server assumptions based on the number of devices.

Is this maybe looking ahead at running an unsharded model, disaggregated across multiple devices? If so, at some point, probably soon, we should introduce some kind of `ModelTopology` abstraction that could specify how each device should be used: whether it's sharded or disaggregated, which devices to use for prefill, which devices to use for decode, etc. But it would be messy to do that without an abstraction layer.

If that's the reason for this change, can we simplify the PR to only support two HIP streams on a single device, then extend to multi-device? That seems like it would play more easily with our current sharding code. I may also just be reading something wrong.
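As a rough illustration of the abstraction suggested above, a `ModelTopology` could look something like this (a sketch only; the class and every field name here are hypothetical, not part of this PR):

```python
from dataclasses import dataclass, field


@dataclass
class ModelTopology:
    """Hypothetical: declares how logical HAL devices are used,
    instead of inferring the topology from len(devices)."""

    tp_degree: int = 1           # tensor-parallel sharding degree (1 = unsharded)
    disaggregated: bool = False  # prefill/decode on independent streams
    prefill_devices: list[int] = field(default_factory=lambda: [0])
    decode_devices: list[int] = field(default_factory=lambda: [0])


# e.g. an unsharded model disaggregated across two HAL devices on one GPU:
topology = ModelTopology(disaggregated=True, prefill_devices=[0], decode_devices=[1])
```

With something explicit like this, the server could build programs and fibers per role rather than guessing intent from the device count.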
```python
workers = [self.sysman.ls.create_worker(f"{task}-worker") for task in task_list]
fibers = [
    self.sysman.ls.create_fiber(
        workers[idx], devices=[devices[idx % len(devices)]]
```
The `idx % len(devices)` part here confuses me.

For the unsharded case with one device, we only create two workers, so this should only ever create two fibers for device0. If so, it'd be more explicit to just say `devices[0]`.

For the sharded case with `tp` devices, this would create a fiber for device0 and for device1, even if no `--disaggregate` option was passed to the server. Is that intended?
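To make the two cases concrete, here is how the `idx % len(devices)` indexing behaves (a standalone illustration, not code from the PR):

```python
# Unsharded: one device, two workers (prefill and decode).
devices = ["device0"]
print([devices[idx % len(devices)] for idx in range(2)])
# ['device0', 'device0']  -- i.e. always devices[0]

# Sharded (tp=2): two devices, two workers.
devices = ["device0", "device1"]
print([devices[idx % len(devices)] for idx in range(2)])
# ['device0', 'device1']  -- one fiber per device, regardless of --disaggregate
```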
f"{self.model_params.module_name}.prefill_bs{bs}" | ||
] | ||
# Resolve decode entrypoints. | ||
self.decode_functions = {} | ||
for bs in self.model_params.decode_batch_sizes: | ||
self.decode_functions[bs] = self.inference_program[ | ||
self.decode_functions[bs] = self.inference_program[1 % num_devices][ |
Same question here
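For reference, `1 % num_devices` picks a different program depending on the device count, which is the same implicit-topology concern as above:

```python
for num_devices in (1, 2, 8):
    print(num_devices, 1 % num_devices)
# 1 0  -> decode resolved from program 0 (shared with prefill)
# 2 1  -> decode resolved from program 1
# 8 1  -> decode resolved from program 1
```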
Force-pushed from 862d8fc to 01931fd
@stbaione Great point about the sharded case; I had it in the back of my mind that I needed to handle that, but must have forgotten it. I have separated all the code into new classes and added an assertion to check that the server is running on a physical GPU when executing in disaggregated mode. The solution is not very elegant, but it is the best I could come up with that early in the pipeline. Sharding should be okay now.
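The guard described above might look roughly like the following (a sketch under assumptions: the flag handling, the device iteration, and the string-based device-type check are all illustrative, not the PR's actual code):

```python
def ensure_disaggregation_supported(disaggregate: bool, devices: list) -> None:
    # Illustrative check: disaggregation relies on multiple HIP streams,
    # so reject any device that is not a physical AMD GPU.
    if disaggregate and not all("amdgpu" in str(d) for d in devices):
        raise RuntimeError(
            "--disaggregate requires physical GPU (HIP) devices; "
            f"got {[str(d) for d in devices]}"
        )
```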
Force-pushed from ad3b24a to 9214f22
```diff
@@ -37,6 +37,7 @@ def __init__(
     self.input_token_ids = input_token_ids
     self.prompt_length = len(input_token_ids)
     self.done = sf.VoidFuture()
+    self.completed = sf.VoidFuture()
```
This either needs more comments or better names. What is the difference between `done` and `completed`?
@AWoloszyn addressed in the latest commit.
Force-pushed from 9214f22 to e8b4652
Force-pushed from e8b4652 to 4bf080c
```diff
@@ -274,7 +274,7 @@ def start(self):
     self.model_params,
     self.prefill_functions,
     self.prog_isolation,
-    fibers[0],
+    exec_fibers[0],
```
I recommend using named constants for the 0 and 1 rather than magic numbers (here and in `inference_program` below). That should make it easier to reason about.
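For example, something along these lines (names illustrative, not from the PR):

```python
# Named roles for the two fibers/programs created in disaggregated mode,
# replacing the bare 0 and 1.
PREFILL = 0
DECODE = 1

exec_fibers = ["prefill-fiber", "decode-fiber"]  # stand-ins for real fiber objects
print(exec_fibers[PREFILL], exec_fibers[DECODE])
```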
Implements initial support for running prefill and decode invocations in the LLM server on independent HIP streams, by creating two HAL devices on the same physical device. The patch is enabled by passing `--disaggregate` to the server script invocation.
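Schematically, the disaggregated path wires things up along these lines (a sketch reusing the worker/fiber calls visible in the diffs above; the helper function itself and the device-splitting detail are assumptions, not the PR's exact code):

```python
def build_disaggregated_fibers(sysman):
    # One physical GPU exposed as two HAL devices, so prefill and decode
    # each get their own worker, fiber, and HIP stream.
    prefill_dev, decode_dev = sysman.ls.devices[:2]
    prefill_worker = sysman.ls.create_worker("prefill-worker")
    decode_worker = sysman.ls.create_worker("decode-worker")
    prefill_fiber = sysman.ls.create_fiber(prefill_worker, devices=[prefill_dev])
    decode_fiber = sysman.ls.create_fiber(decode_worker, devices=[decode_dev])
    return prefill_fiber, decode_fiber
```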