
GRPO Improvement checklist #2421

Open
@RedTachyon

Description

It's alive! It's alive! (#2326)

The recipe definitely works (as in, I can run it and reach roughly a 60% success rate on GSM8k with a 3B model), but it's somewhat barebones and under-optimized. Here, I want to keep track of the most important missing features and improvements. I'll probably work through this list at some point, at some pace, but if anyone else wants to contribute, feel free to grab something from it.

Improvements

  • Figure out how to move things from dev into the main repository (decide on the final APIs etc.)
  • A proper eval workflow, so that every once in a while we run a full eval on the test set. Alternatively, a separate evaluation recipe? (Note: I already have a working separate eval recipe that follows the same paradigm as GRPO, runs through the full test dataset, computes success/reward, and saves the results to a file inside the checkpoint. Happy to make it a PR.)
  • Adding proper (unit) tests, compliant with the normal torchtune testing workflow
  • Adding proper documentation to everything
  • Refactoring of the GRPO losses (there's also a research-y question of what loss should be used - there's some ambiguity in the paper and reference implementations)
  • More modular approach to reward computation (probably as a component with the regular OmegaConf setup)
  • Step-based checkpointing (Implement step based checkpointing #2384) - this is pretty important, since one epoch can be very long, leading to very infrequent checkpoints
  • Memory profiling and experiments on "controlled" hardware (a recipe tuned to work ~optimally on a node of 8xH100, or on a single H100, or on smaller hardware with e.g. LoRA)
  • Optimization of the default recipe - maybe we can get a big performance boost e.g. by doing ppo_epochs>1?
  • Dataset improvements (gsm8k is functional but could use some polish; there are also the MATH and DeepscaleR datasets with a similar structure that can be added)
  • A single-device version - should be pretty simple, but probably also slow. It might require gradient accumulation to work properly, since I tend to get bad results with small batch sizes.
  • Try to improve generation speed by using vLLM (or something else)?
  • Probably more to be found soon
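For the loss refactor, one possible starting point is the group-relative advantage that gives GRPO its name: each completion's reward r_i is normalized by the mean and standard deviation of the rewards in its own group (all completions sampled for the same prompt), i.e. A_i = (r_i - mean(r)) / (std(r) + eps). The sketch below is a minimal, dependency-free version assuming rewards are laid out group by group; it is not torchtune's implementation, and choices such as population vs. sample standard deviation and the value of `eps` are exactly the kind of detail where reference implementations differ.

```python
from statistics import fmean, pstdev
from typing import List, Sequence


def group_relative_advantages(
    rewards: Sequence[float], group_size: int, eps: float = 1e-4
) -> List[float]:
    """Normalize each reward against the other completions for the same prompt.

    Assumes `rewards` is flat and ordered so that every consecutive run of
    `group_size` entries belongs to one prompt.
    """
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")
    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start : start + group_size]
        mean = fmean(group)
        # Population std; some implementations use the sample (n - 1) std instead.
        std = pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts, four completions each: the first group has mixed outcomes,
# the second group failed everywhere.
advs = group_relative_advantages(
    [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], group_size=4
)
```

A group where every completion receives the same reward gets all-zero advantages, so it contributes nothing to the advantage-weighted part of the loss; prompts that are always solved (or never solved) mostly cost generation time.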
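For a more modular reward computation, one option is to treat each reward term as a small callable and combine them with weights, so that each term could later be declared in the YAML config (torchtune configs build objects from `_component_` entries). The `<answer>...</answer>` tag format, the function names and the weights below are all hypothetical placeholders, not the recipe's current reward.

```python
import re
from dataclasses import dataclass
from typing import Callable, Sequence

# A reward term maps (completion text, reference answer) -> float.
RewardFn = Callable[[str, str], float]

# Hypothetical output format: the final answer is wrapped in <answer> tags.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def format_reward(completion: str, answer: str) -> float:
    """1.0 if the completion contains an <answer>...</answer> block."""
    return 1.0 if _ANSWER_RE.search(completion) else 0.0


def correctness_reward(completion: str, answer: str) -> float:
    """1.0 if the tagged answer matches the reference exactly (after stripping)."""
    match = _ANSWER_RE.search(completion)
    return 1.0 if match and match.group(1).strip() == answer.strip() else 0.0


@dataclass
class WeightedReward:
    fn: RewardFn
    weight: float = 1.0


class CompositeReward:
    """Weighted sum of independent reward terms."""

    def __init__(self, components: Sequence[WeightedReward]):
        self.components = list(components)

    def __call__(self, completion: str, answer: str) -> float:
        return sum(c.weight * c.fn(completion, answer) for c in self.components)


reward = CompositeReward(
    [WeightedReward(correctness_reward, 1.0), WeightedReward(format_reward, 0.5)]
)
```

Keeping each term separate also makes it easy to log the terms individually, which helps when a run starts optimizing the format reward instead of correctness.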
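Step-based checkpointing and a periodic eval both reduce to the same scheduling check on a global step counter. The sketch below assumes hypothetical `save_every_n_steps` / `eval_every_n_steps` settings and caller-supplied `save_fn` / `eval_fn` callbacks; it only illustrates the control flow and a plain JSON file for eval results inside a checkpoint directory, not torchtune's checkpointer API.

```python
import json
from pathlib import Path
from statistics import fmean
from typing import Callable, Sequence


def should_run(step: int, every_n_steps: int) -> bool:
    """True on every `every_n_steps`-th step; a value <= 0 disables the hook."""
    return every_n_steps > 0 and step % every_n_steps == 0


def summarize_eval(successes: Sequence[bool], rewards: Sequence[float]) -> dict:
    """Aggregate per-example eval results into success rate and mean reward."""
    return {
        "success_rate": fmean([float(s) for s in successes]),
        "mean_reward": fmean(rewards),
        "num_samples": len(rewards),
    }


def write_eval_metrics(checkpoint_dir: Path, step: int, metrics: dict) -> Path:
    """Store eval metrics as JSON next to the checkpoint they belong to."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    path = checkpoint_dir / f"eval_step_{step}.json"
    path.write_text(json.dumps(metrics, indent=2))
    return path


def run_training(
    num_steps: int,
    save_every_n_steps: int,
    eval_every_n_steps: int,
    save_fn: Callable[[int], None],
    eval_fn: Callable[[int], None],
) -> None:
    for step in range(1, num_steps + 1):
        # ... one GRPO update (generation, rewards, policy step) would go here ...
        if should_run(step, save_every_n_steps):
            save_fn(step)
        if should_run(step, eval_every_n_steps):
            eval_fn(step)
```

Tying the schedule to optimizer steps rather than epochs means checkpoint frequency no longer depends on dataset size, which is the problem described in the step-based checkpointing bullet.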

Bugs

Because of course I found a bug right after everything was finalized. There will likely be more, so this subsection might or might not be useful.

  • Right now, generate_trajectory_batched can crash when different generations have different lengths. For example, one batch of completions ran to the full 512 tokens, but another was truncated at 300 because every sequence in it hit a stop token. That gives tensors of shapes [16, 512] and [16, 300], and concatenating them along the zeroth axis obviously doesn't work. The tensors need to be padded to a consistent length.
  • Very, very rarely, an invalid token seems to be sampled, for example token 128011, which is an undefined special token with the standard config. When we try to decode it for the reward computation, the entire program crashes because tiktoken can't handle the unknown token. This can probably be handled by replacing undefined generated tokens with pad_id or something similar. As to why these tokens are sampled at all: the model probably assigns them a very low probability, say 1e-7, but if you sample 1e7 new tokens, chances are it will happen at some point.
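For the generate_trajectory_batched crash, the fix amounts to right-padding every chunk of completions to the longest length before concatenating along the batch axis. The sketch below uses plain Python lists to stay dependency-free; with tensors, the same step would be `torch.nn.functional.pad(t, (0, max_len - t.size(1)), value=pad_id)` on each chunk followed by `torch.cat(..., dim=0)`. It assumes downstream masking already treats `pad_id` positions as padding.

```python
from typing import List, Sequence


def concat_with_padding(
    chunks: Sequence[Sequence[Sequence[int]]], pad_id: int
) -> List[List[int]]:
    """Concatenate chunks of token rows along the batch axis after right-padding.

    Each chunk is a list of rows (one row per completion); rows in different
    chunks may have different lengths, e.g. 512 vs. 300 tokens.
    """
    rows = [list(row) for chunk in chunks for row in chunk]
    if not rows:
        return []
    max_len = max(len(row) for row in rows)
    # Right-pad every row with pad_id so all rows share the same length.
    return [row + [pad_id] * (max_len - len(row)) for row in rows]


# A full-length chunk and a chunk that stopped early.
merged = concat_with_padding([[[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10]]], pad_id=0)
```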
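The suggested workaround for undefined tokens is to map them to `pad_id` before decoding. A minimal sketch, assuming the set of defined token IDs can be assembled from the tokenizer's base vocabulary plus its configured special tokens (the toy IDs below are made up, not the real tokenizer's values). The second function puts a number on the "1e7 samples" intuition: the chance of at least one hit in n independent draws with probability p is 1 - (1 - p)^n, which for p = 1e-7 and n = 1e7 is about 1 - 1/e, roughly 63%.

```python
import math
from typing import AbstractSet, Iterable, List


def sanitize_tokens(
    tokens: Iterable[int], valid_ids: AbstractSet[int], pad_id: int
) -> List[int]:
    """Replace any token ID the tokenizer cannot decode with pad_id."""
    return [tok if tok in valid_ids else pad_id for tok in tokens]


def prob_at_least_once(p: float, n: int) -> float:
    """P(at least one occurrence in n independent draws) = 1 - (1 - p) ** n.

    Uses log1p/expm1 so the result stays accurate for tiny p and huge n.
    """
    return -math.expm1(n * math.log1p(-p))


# Toy vocabulary: regular IDs 0..9 plus two defined special tokens.
valid = set(range(10)) | {100, 101}
clean = sanitize_tokens([1, 100, 105, 3], valid, pad_id=0)
risk = prob_at_least_once(1e-7, 10**7)
```

This only prevents the decoding crash; whether those positions should also be excluded from the loss is a separate decision.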

Note to maintainers - I took the liberty of creating this centralized checklist since I still have all the necessary improvements in my context window. In principle, each bullet point could be a separate issue, but that would probably be a nightmare. We can coordinate the effort around this issue and start adding the improvements, one PR at a time.


Metadata

    Assignees: No one assigned
    Labels: community help wanted (We would love the community's help completing this issue)
    Type: No type
    Projects: No projects
    Milestone: No milestone
    Relationships: None yet
    Development: No branches or pull requests
