[migration] copy old docs, examples, integrations, scripts #1133
Conversation
Code Review
This pull request copies over a large number of documentation files, examples, and scripts, providing a wealth of examples for various training scenarios: different algorithms, integrations, and large-scale training setups. My review focuses on the correctness and maintainability of the new scripts and code. I've identified a few areas for improvement: a bug in a data preprocessing script where a command-line argument is ignored, code duplication across example files that could be refactored, and a minor formatting issue in the documentation.
```python
# Define input and output files
DATA_DIR = Path.home() / "data/dapo"
```
`DATA_DIR` is hardcoded, but the calling script `prepare_dapo_data.sh` passes a `--data-dir` argument that is currently ignored. To make this script more flexible and align it with its usage, use `argparse` to handle command-line arguments. For example:
```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", default=str(Path.home() / "data/dapo"))
args = parser.parse_args()
DATA_DIR = Path(args.data_dir)
```
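With that change, the `--data-dir` value passed by `prepare_dapo_data.sh` is actually honored; an invocation along the lines of `python prepare_dapo_data.py --data-dir "$HOME/data/dapo"` (the Python file name here is assumed for illustration) would then write under the given directory.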
| """ | ||
| Custom trainer for DAPO. | ||
|
|
||
| Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards. | ||
| """ | ||
|
|
||
| @torch.no_grad() | ||
| def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: | ||
| """ | ||
| Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards. | ||
|
|
||
| Args: | ||
| generator_output: GeneratorOutput | ||
| uids: List[str] | ||
|
|
||
| Returns: | ||
| GeneratorOutput | ||
| """ | ||
| overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len | ||
| overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor | ||
| # modify rewards here | ||
| response_ids = generator_output["response_ids"] | ||
| rewards = generator_output["rewards"] | ||
|
|
||
| assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here" | ||
|
|
||
| # get the response length | ||
| response_lengths = [len(response) for response in response_ids] | ||
|
|
||
| # get the max context length | ||
| # NOTE: this is only valid for single turn generation | ||
| max_response_length = self.cfg.generator.sampling_params.max_generate_length | ||
|
|
||
| # apply soft overlong punishment | ||
| for i, response_length in enumerate(response_lengths): | ||
| # max_exceed_length is the beginning of the overlong buffer | ||
| max_exceed_length = max_response_length - overlong_buffer_len | ||
| # if the response is within the overlong buffer, apply the penalty | ||
| if response_length > max_exceed_length and response_length <= max_response_length: | ||
| exceed_length = response_length - max_exceed_length | ||
| penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor | ||
|
|
||
| rewards[i] -= penalty | ||
| # if the response is outside the overlong buffer, set the reward to 0 | ||
| elif response_length > max_response_length: | ||
| # if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses | ||
| rewards[i] = 0.0 | ||
|
|
||
| generator_output["rewards"] = rewards | ||
|
|
||
| # use base class impl for metrics and per-token reward conversion | ||
| return super().postprocess_generator_output(generator_output, uids) | ||
|
|
The `DAPOTrainer` class is very similar to the one in `skyrl/examples/train/algorithms/dapo/main_dapo.py`. To improve maintainability and avoid code duplication, consider moving this class to a shared module and importing it in both places. The small differences in logic could be handled with configuration flags (a consolidated sketch follows the next snippet's comment).
```python
class DAPOTrainer(RayPPOTrainer):
    """
    Custom trainer for DAPO.

    Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards.
    """

    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        """
        Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

        Args:
            generator_output: GeneratorOutput
            uids: List[str]

        Returns:
            GeneratorOutput
        """
        overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
        overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
        # modify rewards here
        prompt_token_ids = generator_output["prompt_token_ids"]
        response_ids = generator_output["response_ids"]
        rewards = generator_output["rewards"]

        assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here"

        # get the prompt length
        prompt_lengths = [len(prompt) for prompt in prompt_token_ids]

        # get the response length
        response_lengths = [len(response) for response in response_ids]

        # get the max context length
        max_context_length = (
            self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length
        )

        # apply soft overlong punishment
        for i, (prompt_length, response_length) in enumerate(zip(prompt_lengths, response_lengths)):
            # max_exceed_length is the beginning of the overlong buffer
            max_exceed_length = max_context_length - overlong_buffer_len - prompt_length
            # if the response is within the overlong buffer, apply the penalty
            if response_length > max_exceed_length and response_length <= max_context_length - prompt_length:
                exceed_length = response_length - max_exceed_length
                penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor
                rewards[i] -= penalty
            # if the response is outside the overlong buffer, set the reward to 0
            elif response_length > max_context_length - prompt_length:
                # if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses
                rewards[i] = 0.0

        generator_output["rewards"] = rewards

        # use base class impl for metrics and per-token reward conversion
        return super().postprocess_generator_output(generator_output, uids)
```
The `DAPOTrainer` class is very similar to the ones in `skyrl/examples/train/algorithms/dapo/main_dapo.py` and `skyrl/examples/train/flash_rl/main_dapo_flashrl.py`. To improve maintainability and avoid code duplication, consider refactoring these into a single, more configurable `DAPOTrainer` class in a shared module, with the small differences in logic handled by configuration flags, as sketched below.
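A minimal sketch of what that refactor could look like, assuming a shared module; the module path, the `account_for_prompt_length` flag, and the exact import paths below are hypothetical:

```python
# Hypothetical shared module, e.g. examples/train/utils/dapo_trainer.py.
import torch
from typing import List

from skyrl_train.trainer import RayPPOTrainer       # import path is an assumption
from skyrl_train.generators import GeneratorOutput  # import path is an assumption


class DAPOTrainer(RayPPOTrainer):
    """Shared DAPO trainer applying soft overlong punishment to rewards."""

    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        buffer_cfg = self.cfg.trainer.algorithm.overlong_buffer
        rewards = generator_output["rewards"]
        assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here"

        response_lengths = [len(r) for r in generator_output["response_ids"]]

        if buffer_cfg.account_for_prompt_length:  # hypothetical config flag
            # per-sample budget: total context length minus this sample's prompt length
            max_context_length = (
                self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length
            )
            prompt_lengths = [len(p) for p in generator_output["prompt_token_ids"]]
            max_lengths = [max_context_length - pl for pl in prompt_lengths]
        else:
            # fixed budget: only valid for single-turn generation
            max_lengths = [self.cfg.generator.sampling_params.max_generate_length] * len(response_lengths)

        for i, (response_length, max_length) in enumerate(zip(response_lengths, max_lengths)):
            # the overlong buffer occupies the last buffer_cfg.len tokens of the budget
            buffer_start = max_length - buffer_cfg.len
            if buffer_start < response_length <= max_length:
                # linearly ramp the penalty across the buffer
                rewards[i] -= (response_length - buffer_start) / buffer_cfg.len * buffer_cfg.penalty_factor
            elif response_length > max_length:
                rewards[i] = 0.0

        generator_output["rewards"] = rewards
        return super().postprocess_generator_output(generator_output, uids)
```

Each example's `main_dapo*.py` would then just import this class and set the flag in its config.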
```markdown
1. **Disk-based synchronization**: LoRA adapters are saved to disk and reloaded rather than synchronized in-memory.
...
4. **Single adapter per model**: Currently, only one LoRA adapter can be active per model at a time.
```
```python
except Exception as e:
    error = str(e)
    observation = None
    reward = -1
```
🔴 **`UnboundLocalError`: `done` variable not set when an exception occurs in `OpenEnv.step()`**

When `_get_openenv_action` or `self.env.step(action)` raises an exception and `max_turns_reached` is `False`, the code falls through to the `else` branch at line 162 and references `done` at line 184, which was never assigned.

**Root Cause**

The `done` variable is only assigned inside the `try` block at line 153 (`done = result.done`). When an exception is caught at line 154, `done` is never set: the `except` block sets `error`, `observation`, and `reward`, but not `done`. If `max_turns_reached` is `False` (line 159), execution reaches line 184, where `done` is used in `BaseTextEnvStepOutput(..., done=done)`, raising an `UnboundLocalError` at runtime. This happens whenever the LLM generates an action that can't be parsed (e.g., no `<action>` tags) and the environment hasn't reached max turns yet.

**Impact**: The environment crashes with an `UnboundLocalError` instead of gracefully handling the error, which could cause the entire training step to fail.
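For illustration, a self-contained snippet (not the project code) that reproduces the underlying Python scoping behavior:

```python
def step(fail: bool):
    try:
        if fail:
            raise ValueError("boom")
        done = True  # only assigned on the happy path
    except Exception as e:
        error = str(e)  # note: 'done' is never assigned on this path
    return done  # fine when fail=False


step(False)  # returns True
step(True)   # raises UnboundLocalError for 'done'
```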
```diff
 except Exception as e:
     error = str(e)
     observation = None
     reward = -1
+    done = False
```
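An alternative to the suggested diff is initializing `done = False` before the `try` block, which keeps every exception path covered even if more `except` clauses are added later.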
```python
action = matches[-1] if len(matches) > 0 else None

if not action:
    raise ValueError(f"No action found in action string: {action}")
```
🟡 **Error message always shows `None` instead of the original action string in `_get_openenv_action`**

When no `<action>` tags are found in the LLM output, the error message at line 110 is useless because it references the already-overwritten `action` variable.

**Root Cause**

At line 107, the `action` parameter is reassigned: `action = matches[-1] if len(matches) > 0 else None`. When no matches are found, `action` becomes `None`, so the message at line 110, `f"No action found in action string: {action}"`, always prints `No action found in action string: None` instead of showing the original LLM output that failed to parse. The original action string (the LLM's full response) is lost.

**Impact**: Debugging is significantly harder because the error message doesn't show what the LLM actually generated.
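For illustration, a simplified standalone version (the regex and surrounding structure are assumed from the review's description) showing why the message always prints `None`:

```python
import re


def _get_openenv_action(action: str):
    matches = re.findall(r"<action>(.*?)</action>", action)
    action = matches[-1] if len(matches) > 0 else None  # parameter 'action' is overwritten here
    if not action:
        # the original LLM output is already gone, so this always prints None
        raise ValueError(f"No action found in action string: {action}")
    return action


_get_openenv_action("some LLM response without action tags")
# ValueError: No action found in action string: None
```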
```diff
-action = matches[-1] if len(matches) > 0 else None
-if not action:
-    raise ValueError(f"No action found in action string: {action}")
+parsed_action = matches[-1] if len(matches) > 0 else None
+if not parsed_action:
+    raise ValueError(f"No action found in action string: {action}")
+action = parsed_action
```
Copy over old docs, examples, integrations, and scripts.

WIP: still need to make sure all of these run against the new refactored code; import paths will need to change, and everything needs testing.

cc: @CharlieFRuan