Add support for vllm sampler kwargs. #1169

Open
NicoGrande wants to merge 1 commit into main from nicogrande/add-vllm-sampler-kwargs

Conversation

@NicoGrande (Collaborator)

Add support for sampling kwargs to be passed in via vLLMConfig. This mirrors the recent engine_kwargs argument added to vLLMConfig.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist (bot)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the configurability of vLLM integrations by providing a dedicated mechanism to pass sampling-related keyword arguments directly through the configuration. This change allows for more granular control over generation parameters, aligning the sampler configuration with the existing engine configuration capabilities. The update propagates these new configuration options across the VllmConfig and RolloutConfig structures, making it easier to customize vLLM's behavior in various contexts.

Highlights

  • VllmConfig Enhancement: Introduced a new sampler_kwargs field to VllmConfig to allow direct passing of vLLM sampler arguments such as temperature or stop tokens, mirroring the existing engine_kwargs.
  • VllmSampler Integration: Modified the VllmSampler's __call__ method to process and apply the sampler_kwargs from its configuration to the underlying vLLM SamplingParams object, ensuring these parameters are used during generation.
  • Rollout Configuration Update: Added rollout_vllm_sampler_kwargs to RolloutConfig in base_rollout.py to enable configuration of vLLM sampling parameters specifically for rollout processes.
  • VllmRollout Initialization: Integrated the new rollout_vllm_sampler_kwargs into the VllmConfig instantiation within VllmRollout, ensuring that rollout-specific sampler arguments are correctly passed to the vLLM setup.
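Taken together, the configuration plumbing described in these highlights can be sketched as follows. This is an illustrative assumption, not the exact tunix source: the field names come from the PR description (note that `sampler_kwargs` is later renamed to `sampling_kwargs` during review), and the dataclass shape is a stand-in.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class VllmConfig:
  """Illustrative stand-in for tunix's VllmConfig."""

  # Pre-existing field: kwargs for the vLLM engine constructor.
  engine_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
  # Field added by this PR: kwargs applied to vLLM's SamplingParams,
  # e.g. temperature or stop tokens.
  sampler_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


config = VllmConfig(sampler_kwargs={"temperature": 0.7, "stop": ["</s>"]})
print(config.sampler_kwargs["temperature"])  # 0.7
```

Using `default_factory=dict` rather than a mutable `{}` default avoids sharing one dictionary across all config instances.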

Changelog
  • tunix/generate/vllm_sampler.py
    • Added sampler_kwargs field to VllmConfig with a default empty dictionary.
    • Updated the __call__ method to merge self.config.sampler_kwargs with any provided kwargs.
    • Modified the __call__ method to iterate through combined kwargs and set them as attributes on sampling_params using setattr.
    • Adjusted logging messages to reflect the new handling of sampler kwargs.
  • tunix/rl/rollout/base_rollout.py
    • Added rollout_vllm_sampler_kwargs field to RolloutConfig for vLLM sampler arguments.
    • Updated the comment for rollout_vllm_kwargs to specify it's for the vLLM engine constructor.
  • tunix/rl/rollout/vllm_rollout.py
    • Modified the VllmRollout constructor to pass rollout_vllm_sampler_kwargs from RolloutConfig to the VllmConfig instance.
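The merge-and-apply behavior described for `__call__` amounts to the following sketch. Both `apply_sampler_kwargs` and the stub `SamplingParams` are hypothetical names for illustration; the real `SamplingParams` comes from vLLM.

```python
class SamplingParams:
  """Minimal stub of vLLM's SamplingParams, for illustration only."""

  def __init__(self):
    self.temperature = 1.0
    self.stop = None


def apply_sampler_kwargs(sampling_params, config_kwargs, call_kwargs):
  # Merge configured defaults with call-time kwargs; call-time
  # values win for overlapping keys, then each pair is applied
  # to the params object via setattr (as the changelog describes).
  combined = {**config_kwargs, **call_kwargs}
  for key, value in combined.items():
    setattr(sampling_params, key, value)
  return sampling_params


params = apply_sampler_kwargs(
    SamplingParams(), {"temperature": 0.7}, {"stop": ["</s>"]}
)
print(params.temperature, params.stop)  # 0.7 ['</s>']
```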
Activity
  • The author has completed all items in the provided checklist, including adding necessary unit tests, verifying code integrity, adding documentation, rebasing on the main branch, signing the Contributor License Agreement, and following contribution guidelines.

@NicoGrande force-pushed the nicogrande/add-vllm-sampler-kwargs branch from 3675c58 to b418aa7 on February 26, 2026 at 22:14.

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for passing arbitrary sampler keyword arguments to the vLLM sampler, which increases its flexibility. The changes are well-structured, introducing a sampler_kwargs field in VllmConfig and plumbing it through to the sampler's __call__ method. My review includes a couple of suggestions for improvement in tunix/generate/vllm_sampler.py related to logging verbosity and exception handling to better align with the repository's style guide for robustness and maintainability.

@wang2yn84 (Collaborator) left a comment

Can you add a test for this?

)

# vLLM sampler args that can be directly passed in without additional processing, e.g. temperature, stop etc.
sampler_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
Collaborator

Can we rename to sampling_kwargs?

Collaborator Author

Done.

rollout_vllm_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)

# Additional keyword arguments forwarded directly to the vLLM sampler.
rollout_vllm_sampler_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
Collaborator

Sampling?

Collaborator Author

Done.

@NicoGrande (Collaborator Author)

Can you add a test for this?

I added a couple of unit tests. LMK what you think.
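A test in this spirit could verify that configured kwargs land on the sampling params object without running a real model. The class and method names below are hypothetical, not the actual tests added in the PR:

```python
import unittest
from unittest import mock


class VllmSamplerSamplingKwargsTest(unittest.TestCase):
  """Sketch: check that merged kwargs are set on the params object."""

  def test_sampling_kwargs_applied(self):
    sampling_params = mock.Mock()
    config_kwargs = {"temperature": 0.3}  # from the config
    call_kwargs = {"top_p": 0.9}          # passed at call time
    # Mirror the sampler's merge-and-setattr behavior.
    for key, value in {**config_kwargs, **call_kwargs}.items():
      setattr(sampling_params, key, value)
    self.assertEqual(sampling_params.temperature, 0.3)
    self.assertEqual(sampling_params.top_p, 0.9)


result = unittest.TextTestRunner().run(
    unittest.TestLoader().loadTestsFromTestCase(VllmSamplerSamplingKwargsTest)
)
```

Using a `mock.Mock` in place of a real `SamplingParams` keeps the test independent of model loading, in line with the reviewer's later suggestion to avoid the real HF model.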

self.repo_id, enable_lora=self.enable_lora
)

base_utils.show_hbm_usage("After loading tunix model")
Collaborator

Let's remove this in the test?

Collaborator Author

Sounds good! Will remove!

self.repo_id, enable_lora=self.enable_lora
)

base_utils.show_hbm_usage("After loading tunix model")
Collaborator

Ditto

Collaborator Author

Done.


def test_vllm_sampler_sampling_kwargs(self):
"""Test that sampling kwargs are correctly applied to sampling_params."""
tunix_model, _ = self.load_llama3_model(
Collaborator

Since we are not testing the correctness of the output, shall we use the dummy_model_creator that Tunix offers instead of the real HF model?

Collaborator Author

Done.

sampling_params.seed = seed

if kwargs:
self.config.sampling_kwargs.update(kwargs)
Collaborator

engine_kwargs?

Collaborator Author

This one should be sampling_kwargs, since it's not going to the LLM constructor but to the sampling_params object instead.

Collaborator

I mean the kwargs: where do they come from?

Collaborator Author

Those are the kwargs coming from the generate() method. They can be passed in dynamically.
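In other words, the call-time kwargs from generate() are layered on top of the configured defaults. A small sketch (the values here are made up for illustration):

```python
# Defaults configured once via VllmConfig.sampling_kwargs.
config_sampling_kwargs = {"temperature": 0.7, "top_p": 0.9}
# Kwargs passed dynamically through the generate() call path.
call_kwargs = {"temperature": 0.2}

# Mirrors self.config.sampling_kwargs.update(kwargs): call-time values
# override the configured defaults for overlapping keys.
config_sampling_kwargs.update(call_kwargs)
print(config_sampling_kwargs)  # {'temperature': 0.2, 'top_p': 0.9}
```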

if self.config.sampling_kwargs:
try:
sampling_params.update(**kwargs)
logging.log_first_n(
Collaborator

The vllm config init will not be called more than once?

Collaborator Author

No, I think it will only be called once. The kwargs provided to the call method could theoretically change, which is why we want to keep the update call here.

Collaborator

Yeah, I mean if it will only be called once, maybe we don't need to use logging.log_first_n?

Collaborator Author

Sounds good! Will update this.

Collaborator Author

Done.
