-
Notifications
You must be signed in to change notification settings - Fork 68
Refactor of vLLM bridge #1914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Refactor of vLLM bridge #1914
Conversation
| """ | ||
|
|
||
| def pad_left(seq: List[int], length: int, pad_value: int) -> List[int]: | ||
| seq = seq[:length] # Truncate if too long |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, silently truncating feels wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to not silently truncate anymore
| """Get all generated token ID sequences.""" | ||
| return [c.token_ids for c in self.completions] | ||
|
|
||
| def to_arrays( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be unused? Am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a simpler debug example of an RL loop that doesn't use Tunix that uses it, I just pushed that example as well
| - Dict[str, jax.Array]: Direct flattened params | ||
| - flax.nnx.State: Flax state object | ||
| - flax.nnx.Module: Flax module (state extracted automatically) | ||
| block: If True, wait for transfer completion (always True currently). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this flag there if it doesn't do anything?
| assert len(original) <= length, f"Sequence too long: {len(original)} > {length}" | ||
| return original + [pad_value] * (length - len(original)) | ||
|
|
||
| for i, completion in enumerate(output.completions): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leftover debug output?
| input_tokens = [] | ||
| output_tokens = [] | ||
|
|
||
| def pad_to_left(original: List[int], length: int, pad_value: int) -> List[int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's another implementation of this in api/types.py.
The purpose of the PR is to refactor the existing JAX <> vLLM bridge to expose and interface that's more easily usable by user who want to avoid using Tunix, the main change is the addition of a framework agnostic interface -
VLLMRolloutEnginethat an external RL framework can inherit / wrap around to access as the entry point to the functionality of the bridge