Support tree-rollout #6634
Conversation
Summary of Changes (Gemini Code Assist): This pull request implements a tree-rollout feature for GRPO-based training, aiming to improve the efficiency and performance of multi-turn reasoning. By structuring the sampling process as a tree search, it manages inference steps, reuses intermediate computation, and dynamically adapts branching to make good use of resources while generating the required samples. This change is expected to lead to faster and more resource-effective training cycles for models that rely on multi-turn interactions.
Code Review
This pull request introduces a tree-rollout mechanism inspired by TreePO for multi-turn reasoning, which aims to reduce inference overhead and improve training efficiency. The implementation includes a new TreeRolloutScheduler, data structures for representing the tree, and divergence strategies. The changes are well-structured and the new feature is a valuable addition.
My review includes several points for improvement:
- A couple of imports inside functions should be moved to the top of their respective files for better code style and to avoid potential circular dependencies.
- The unused variable `step_efficiency_metrics` should be removed.
- A variable reassignment could be made clearer to improve readability.
- There is a class-hierarchy issue where a synchronous method overrides an asynchronous one, violating the Liskov Substitution Principle.
- A minor style issue in a shell script: a missing newline at the end of the file.
Overall, the changes are good, and addressing these points will improve code quality and maintainability.
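To make the Liskov Substitution Principle point concrete, here is a minimal illustrative sketch (class names are hypothetical, not the PR's actual classes) of why a synchronous override of an async method breaks substitutability:

```python
import inspect

class AsyncScheduler:
    # The base class declares an asynchronous interface.
    async def run(self, requests):
        return [f"async:{r}" for r in requests]

class SyncScheduler(AsyncScheduler):
    # Overriding the async method with a sync one changes the contract:
    # any caller that does `await scheduler.run(...)` breaks on the subclass.
    def run(self, requests):
        return [f"sync:{r}" for r in requests]

# The mismatch is detectable with inspect:
base_is_coro = inspect.iscoroutinefunction(AsyncScheduler.run)  # True
sub_is_coro = inspect.iscoroutinefunction(SyncScheduler.run)    # False
```

A caller written against the base class cannot substitute the subclass without code changes, which is exactly the LSP violation flagged above.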
[fix] Adopt AI suggestions Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
```python
if self.args.tree_rollout:
    all_outputs: List[RolloutOutput] = self.multi_turn_scheduler.run(
        infer_request=all_requests, request_config=request_config)
else:
    all_outputs: List[RolloutOutput] = self._engine_infer(
        infer_requests=all_requests, request_config=request_config)
```
scheduler.run is executed in server mode with the async engine. For colocate mode, the async engine will be supported soon, so we may not need to modify this part now.
Thanks for the suggestion! I have one point I’d like to clarify.
scheduler.run defines the full multi_turn_infer workflow, but from what I can see in the current codebase, this method is not actually used. I also noticed that _colocate_multi_turn_infer is implemented in rollout_mixin, but tree-rollout cannot reuse that logic.
If we remove this branch now, how can we ensure that the custom multi-turn logic for tree-rollout scheduler will still be executed?
got it, thanks a lot
```diff
 vllm_use_async_engine = [self.vllm_client.use_async_engine]
 use_gym_env = [self.vllm_client.use_gym_env]
-enable_multi_turn = [self.vllm_client.enable_multi_turn]
+enable_multi_turn = [self.vllm_client.enable_multi_turn or args.tree_rollout]
```
See the reasons below.
```python
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
stop=args.stop_words,
logprobs=args.tree_rollout,
```
can we set it in scheduler.run?
FYI, logprobs will be set to true in #6678 if args.use_vllm is true.
```python
class TreeRolloutScheduler:
    """
    Base class for multi-turn tree-rollout scheduling.
    Provides default implementation for multi-turn conversation management.

    CUSTOMIZATION:
    Implement the required `step()` method and optionally override `check_finished()`
    - Uses TreeRolloutScheduler's run() method infrastructure
    - Only need to implement turn transition logic in step()
    - Optionally customize termination conditions
    """

    def __init__(self, vllm_client: VLLMClient = None, *args, **kwargs):
        self.max_tree_width = kwargs['args'].num_generations
        self.max_tree_deep = kwargs['args'].max_tree_deep
        self.max_divergence = kwargs['args'].tree_max_divergence
        self.divergence_strategy = kwargs['args'].tree_divergence_strategy
        self.root_divergence = kwargs['args'].tree_root_divergence

        self.vllm_client = vllm_client
        self.executor = ThreadPoolExecutor(max_workers=self.max_tree_width)
```
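The customization contract in the docstring (subclass, implement `step()`, optionally override `check_finished()`, let `run()` drive the loop) can be sketched with a self-contained toy class. All names and signatures below are illustrative, not the PR's exact API:

```python
# Toy sketch of the scheduler customization pattern described above.
class ToyTreeScheduler:
    def __init__(self, max_turns=4):
        self.max_turns = max_turns

    def check_finished(self, sample):
        # Default termination condition: depth limit reached.
        return sample['turn'] >= self.max_turns

    def step(self, sample, response):
        # Turn-transition logic: record the model response and advance a turn.
        sample['messages'].append({'role': 'assistant', 'content': response})
        sample['turn'] += 1
        return sample

    def run(self, sample, responses):
        # run() provides the loop infrastructure; subclasses only vary
        # step() and check_finished().
        for resp in responses:
            if self.check_finished(sample):
                break
            sample = self.step(sample, resp)
        return sample

out = ToyTreeScheduler(max_turns=2).run({'messages': [], 'turn': 0}, ['a', 'b', 'c'])
# out['turn'] == 2; only two responses are consumed before termination
```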
Could you please move it to the example directory, as the deepeyes example does?
```python
@dataclass
class DataSampleTree:
    """
    Attributes:
        tree_idx (str):
            For example '0/1-2/2-3/4-0': root_node = 0; the next node '1-2'
            is sample index 2 of infer batch 1.
        last_response (ChatCompletionResponseChoice):
            vLLM output of the previous round.
    """
    tree_idx: str
    request_id: str

    messages: Messages
    logprobs: List[List[float]] = field(default_factory=list)

    all_response_ids: List[List[int]] = field(default_factory=list)
    last_response: ChatCompletionResponseChoice = None

    token_count_per_step: List[int] = field(default_factory=list)

    status: SampleStatus = SampleStatus.INITIAL
```
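The `tree_idx` path format from the docstring ('/'-separated segments, each non-root segment encoding "batch-index") can be parsed with a few lines; `parse_tree_idx` is a hypothetical helper written only to illustrate the encoding:

```python
# Illustrative parser for the tree_idx format described in the docstring:
# '0/1-2/2-3/4-0' means root node '0', then sample 2 of infer batch 1,
# then sample 3 of batch 2, then sample 0 of batch 4.
def parse_tree_idx(tree_idx: str):
    segments = tree_idx.split('/')
    root, rest = segments[0], segments[1:]
    # Each non-root segment is "batch-index".
    path = [tuple(int(x) for x in seg.split('-')) for seg in rest]
    return root, path

root, path = parse_tree_idx('0/1-2/2-3/4-0')
# root == '0'; path == [(1, 2), (2, 3), (4, 0)]; node depth == len(path)
```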
Same as the scheduler: please move it to the example directory.
```python
multi_turn_scheduler_class = multi_turns[args.multi_turn_scheduler]

if args.tree_rollout:
    assert issubclass(multi_turn_scheduler_class, TreeRolloutScheduler)

multi_turn_scheduler = multi_turn_scheduler_class(
    max_turns=args.max_turns, vllm_client=self.vllm_client, args=args)
```
Same here — please revert it and keep the relevant arguments within the scheduler itself.
```python
# tree rollout
tree_rollout: bool = False
max_tree_deep: Optional[int] = 8
tree_root_divergence: Optional[int] = 1
tree_max_divergence: Optional[int] = 2
tree_divergence_strategy: Optional[str] = 'logprob'
```
please keep the relevant arguments within the scheduler itself.
```shell
    --beta 0.04 \
    --tree_rollout true \
    --multi_turn_scheduler tree_rollout_scheduler \
    --max_tree_deep 4
```
Please create a new directory to include the script and related plugin, making it look more complete, similar to the deepeyes example.
Thanks for your contribution! I've left a few revision suggestions. Meanwhile, the following additions would be great:
PR type
PR information
Inspired by TreePO, we implemented multi-turn tree-rollout based on GRPO, reducing inference overhead and improving training efficiency.
In GRPO-based training, the efficiency of rollouts is critical. The current multi-turn reasoning workflow in Swift is fully streaming-based, which means several key steps are repeatedly executed across generations without reuse. This results in unnecessary duplicated computation and significantly increases inference overhead.
To address this issue, we drew inspiration from TreePO and reformulated the entire sampling phase as a tree-search process. In this design, num_generations is interpreted as the tree’s maximum width, while max_turns represents its maximum depth. Before each reasoning step, the scheduler dynamically determines whether the current node should expand and how many branches it should allocate, based on the current tree width and the branching strategy.
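The per-node divergence decision described above can be sketched as a simple width-budgeting rule. This is a hedged sketch under stated assumptions, not the PR's actual implementation; the function name and the rule itself are illustrative:

```python
# Hypothetical sketch of the divergence decision: given the current leaf
# count and a per-node branching cap, decide how many branches a node may
# open without exceeding the tree's maximum width (num_generations).
def branches_for_node(current_leaves: int, max_width: int, max_divergence: int) -> int:
    remaining = max_width - current_leaves
    if remaining <= 0:
        return 1  # width budget exhausted: continue linearly, no divergence
    # Opening k branches from one node turns one leaf into k, adding k - 1 leaves.
    return min(max_divergence, remaining + 1)

# With a target width of 8, 6 live leaves, and at most 4 branches per node:
k = branches_for_node(current_leaves=6, max_width=8, max_divergence=4)
# k == 3: branching adds 2 leaves, exactly reaching the width of 8
```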
Furthermore, when reasoning terminates early and the number of leaf nodes has not yet reached the maximum width, a backtracking mechanism is activated. The scheduler selects earlier nodes for additional branch expansion, allowing the rollout to continue until the desired sample count is fulfilled.
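The backtracking step above (pick an earlier, still-expandable node when a rollout finishes before the leaf count reaches the target width) can be sketched as a selection over node records. The node representation and the selection policy here are illustrative assumptions, not the PR's data structures:

```python
# Hypothetical backtracking selection: among unfinished nodes that still have
# divergence budget, pick one to branch again. One plausible policy prefers
# deeper nodes so the regenerated branch reuses a longer shared prefix.
def pick_backtrack_node(nodes):
    candidates = [n for n in nodes
                  if not n['finished'] and n['divergences'] < n['max_divergence']]
    if not candidates:
        return None  # nothing left to expand
    return max(candidates, key=lambda n: n['depth'])

nodes = [
    {'id': 'root', 'depth': 0, 'finished': False, 'divergences': 2, 'max_divergence': 2},
    {'id': 'a', 'depth': 1, 'finished': False, 'divergences': 1, 'max_divergence': 2},
    {'id': 'b', 'depth': 2, 'finished': True, 'divergences': 0, 'max_divergence': 2},
]
chosen = pick_backtrack_node(nodes)
# 'root' is out of budget and 'b' is finished, so 'a' is chosen
```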
This tree-rollout approach reuses intermediate computation, reduces redundant inference, and improves overall training efficiency.
Related to issue: Support TreePO #5708
Experiment results