Skip to content

Conversation

@keshavb96
Copy link
Contributor

The purpose of the PR is to refactor the existing JAX <> vLLM bridge to expose and interface that's more easily usable by user who want to avoid using Tunix, the main change is the addition of a framework agnostic interface - VLLMRolloutEngine that an external RL framework can inherit / wrap around to access as the entry point to the functionality of the bridge

"""

def pad_left(seq: List[int], length: int, pad_value: int) -> List[int]:
seq = seq[:length] # Truncate if too long
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, silently truncating feels wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to not silently truncate anymore

"""Get all generated token ID sequences."""
return [c.token_ids for c in self.completions]

def to_arrays(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be unused? Am I missing something?

Copy link
Contributor Author

@keshavb96 keshavb96 Jan 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a simpler debug example of an RL loop that doesn't use Tunix that uses it, I just pushed that example as well

- Dict[str, jax.Array]: Direct flattened params
- flax.nnx.State: Flax state object
- flax.nnx.Module: Flax module (state extracted automatically)
block: If True, wait for transfer completion (always True currently).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this flag there if it doesn't do anything?

assert len(original) <= length, f"Sequence too long: {len(original)} > {length}"
return original + [pad_value] * (length - len(original))

for i, completion in enumerate(output.completions):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover debug output?

input_tokens = []
output_tokens = []

def pad_to_left(original: List[int], length: int, pad_value: int) -> List[int]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another implementation of this in api/types.py.

@keshavb96 keshavb96 marked this pull request as ready for review January 26, 2026 20:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants