Fix intro typos and add JAX RL libraries dropdown#11
Conversation
Fix grammatical errors and typos in the intro paragraph (learns→learn, missing "is", its/it's, notoriety, October 2025, etc.) and add a collapsible dropdown summarizing prior JAX RL repos we evaluated and why none fit our needs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ahmeda14960
left a comment
There was a problem hiding this comment.
Inline comments on all the [AMA] notes for discussion.
|
|
||
| Both Tinker and Marin's Sync RL converged to ~0.43 accuracy on MATH, though Marin took about 2x longer (80 steps vs. Tinker's 35 steps to reach 0.4). | ||
|
|
||
| [[AMA: Do we have evidence / an experiment for this? Seems very speculative, we can say it's a hypothesis but is there an experiment we can link?]] |
There was a problem hiding this comment.
Do we have evidence / an experiment for this? Seems very speculative — we can say it's a hypothesis but is there an experiment we can link?
There was a problem hiding this comment.
Yes, it's just a hypothesis based on the entropy plot in this wandb report
| Other contributing factors included LoRA vs. full fine-tuning requiring different learning rates, and subtle differences in sample/train log-probability divergence. | ||
|
|
||
| Still, this was a major milestone: deterministic, reproducible RL training on TPU, confirmed across 3 independent runs. | ||
| [[ AMA: I changed the mistmatch thing because tinker also acknowledges mismatch between inference and training but we probably have higher mistmatch because they figure out determinisitc generation. Also should lora vs full-tuning be a factor if we tuned our learning rate well?]] |
There was a problem hiding this comment.
I changed the mismatch thing because Tinker also acknowledges mismatch between inference and training, but we probably have higher mismatch because they figure out deterministic generation. Also — should LoRA vs full-tuning be a factor if we tuned our learning rate well?
There was a problem hiding this comment.
Yes, Tinker claims that you just can reduce LoRA LR by 10x when doing full-finetuning but only validated in a few experiments in their LoRA without regret blog. There's no way for me to verify since Marin doesn't have mature LoRA support. The 2x slower convergence could honestly be due to some unknown engineering differences between the two implementations. I tried to come up with some hypothesis but didn't have time to verify any of them.
| Still, this was a major milestone: deterministic, reproducible RL training on TPU, confirmed across 3 independent runs. | ||
| [[ AMA: I changed the mistmatch thing because tinker also acknowledges mismatch between inference and training but we probably have higher mistmatch because they figure out determinisitc generation. Also should lora vs full-tuning be a factor if we tuned our learning rate well?]] | ||
| Other contributing factors likely included LoRA vs. full fine-tuning requiring different learning rates we had to do, and a potentially larger mismatch in sample/train log-probability divergence as we are using vllm for inference and JAX for training [[9]](#ref9). | ||
| [[AMA: I took out 'deterministic' because we're still not exactly determinisitc due to mistmatch]] |
There was a problem hiding this comment.
I took out 'deterministic' because we're still not exactly deterministic due to mismatch.
There was a problem hiding this comment.
Good catch, async RL is still not deterministic due to the in-flight weight sync.
|
|
||
| - **Weight sync**: Arrow Flight transfers model weights from trainer to actor. We added bfloat16 conversion ([PR #2388](https://github.com/marin-community/marin/pull/2388)), cutting transfer bandwidth from 32GB to 16GB and transfer time from 29s to 14s. | ||
| - **In-flight updates**: Background weight sync threads prevent blocking rollouts ([PR #2325](https://github.com/marin-community/marin/pull/2325)). | ||
| [AMA: Kevin are there more PRs for weight sync that show other issues we ran into? this is the part of the code I understand the least TBD so could use your help here] |
There was a problem hiding this comment.
Kevin — are there more PRs for weight sync that show other issues we ran into? This is the part of the code I understand the least, so could use your help here.
There was a problem hiding this comment.
I couldn't find more PRs for weight sync. I remember Chris shared some challenges he bumped into while implementing this feature but they are all in the commit history of the single large RL PR that got broken up.
| - **Weight sync**: On-policy RL works best when the actor is sampling with weights close to the trainer's current policy, so we need to push updated weights to rollout workers frequently. At LLM scale this is hard because each sync moves tens of GB, and if it is too slow you either stall rollout generation waiting for fresh weights or keep sampling from stale policies. | ||
| We were able to improve this by added bfloat16 conversion ([PR #2388](https://github.com/marin-community/marin/pull/2388)), cutting transfer bandwidth in half from 32GB to 16GB and transfer time from 29s to 14s. | ||
| - **In-flight updates**: In an async setup, the trainer wants to publish new weights frequently, but if the actor pauses to wait for every update then inference still becomes part of the critical path. That creates a bad tradeoff between stale policies and idle inference time. We fixed this by adding background weight-sync threads, so rollout workers wait only for the first weights and then continue sampling while newer weights are transferred and hot-reloaded in the background ([PR #2325](https://github.com/marin-community/marin/pull/2325)). | ||
| [ AMA: this isn't really an RL thing... I would remove this it's more of an idiosyncracy of the fact we're using Ray which is crappy, we should take this out] |
There was a problem hiding this comment.
This isn't really an RL thing — I would remove this. It's more of an idiosyncrasy of the fact we're using Ray, which is crappy. We should take this out.
| vLLM's docs specify that `top_k=-1` means "consider all tokens," but this code converted `-1` to `1`, selecting only the highest-probability token regardless of temperature. We filed a bug report ([tpu-inference #1386](https://github.com/vllm-project/tpu-inference/issues/1386)) which suggested a fix that got merged. | ||
|
|
||
| This single bug explained the nondeterminism: tiny sampling differences under greedy sampling, amplified over dozens of RL steps, caused the runs to diverge. | ||
| [AMA: do we have an example of the runs being stable after this? right now we're just saying 'trust me bro' would be nice to show a plot pre fix and post fix!] |
There was a problem hiding this comment.
Do we have an example of the runs being stable after this fix? Right now we're just saying 'trust me bro' — would be nice to show a plot pre-fix and post-fix!
There was a problem hiding this comment.
We don't have a clean before vs after the vllm top-k fix but we have a run after all the async RL sub-PRs are merged, which shows the MATH-500 performance converged to a stable 0.46 (+/-0.02).
github issue comment, wandb run
| We addressed this by implementing a robust combinatorial Pass@k estimator (following the approach from Codex [[12]](#ref12), [lighteval](https://github.com/huggingface/lighteval), and DeepMath [[13]](#ref13)) and increasing the eval sample size K per task to 32 ([PR #2493](https://github.com/marin-community/marin/pull/2493)). | ||
|
|
||
| Training Qwen 2.5 7B on [DeepMath-103K](https://huggingface.co/datasets/PRIME-RL/DeepMath-103K) showed steady Pass@16 gains (reaching 0.40) but Pass@1 remained near zero after 40 steps ([PR #2441](https://github.com/marin-community/marin/pull/2441)). | ||
| [AMA do we have more concrete metrics to back this up? we should be clear when we have a hypothesis / when we are conjecturing |
There was a problem hiding this comment.
Do we have more concrete metrics to back this up? We should be clear when we have a hypothesis vs. when we are conjecturing.
There was a problem hiding this comment.
No other metrics. My intuition is that pass@1 is noisier and harder than pass@16 so there's a threshold that the model needs to go beyond in pass@16 in order for us to see stable improvement in pass@1.
| 5. **Infrastructure is half the battle.** Memory, logging, weight sync, and dependencies required constant attention. | ||
| 1. **Establish baselines first.** As much as we would have loved to zero-shot an Async RL pipeline, we are grateful for Tinker as the baselines saved weeks of debugging by helpiung us sanity check our Sync RL pipeline. | ||
| 2. **Base model choice > algorithm tuning.** Llama 1B failed GSM8K while Llama 8B solved it in 1 step. This highlights the importance of mid-training / pre-training which we will further explore | ||
| 3. **Reproduce before innovating.** Code-R1 and MATH baselines caught subtle environment and prompt bugs. [AMA This is kind of the same as point one] |
There was a problem hiding this comment.
This is kind of the same as point one — should we merge them or differentiate more clearly?
There was a problem hiding this comment.
Yeah, I agree they could be rephrased to be more distinguished from each other, though I think 1 and 4 are sufficiently different. 1 is about using baseline as a guide during development of a new library while 4 is about evaluation metrics.
| 1. **Establish baselines first.** As much as we would have loved to zero-shot an Async RL pipeline, we are grateful for Tinker as the baselines saved weeks of debugging by helpiung us sanity check our Sync RL pipeline. | ||
| 2. **Base model choice > algorithm tuning.** Llama 1B failed GSM8K while Llama 8B solved it in 1 step. This highlights the importance of mid-training / pre-training which we will further explore | ||
| 3. **Reproduce before innovating.** Code-R1 and MATH baselines caught subtle environment and prompt bugs. [AMA This is kind of the same as point one] | ||
| 4. **Evaluation needs care.** AIME25's 30 questions require multi-seed evals and careful estimators [AMA say more about this, what were we doing before? how does the new estimator make things better] [we should add a note about different prompt formatters for math 500!]. |
There was a problem hiding this comment.
Can we say more about this? What were we doing before? How does the new estimator make things better? Also, we should add a note about different prompt formatters for MATH-500.
There was a problem hiding this comment.
True, I can say more: Before to estimate pass@k, we naively sample k trials from a pool of 16 trials. This is very unstable and noisy for small k, like k=1, because the pass rate has high variance across subsamples. With the estimator introduced in this PR, all 16 trials are used to estimate pass@k through a combinatorial formula.
| - GSM8K | ||
| - MATH500 | ||
|
|
||
| (this is kind of obvious to RL practitioners and probably doesn't add much, we can just skip to interesting results) |
There was a problem hiding this comment.
Please see this note. I added this for your context, and I added more context here. I also made it clear in the following paragraph the model sizes and tasks, so I don't think we need line 49 - 53 anymore
- Add markdown="1" so Kramdown renders markdown inside <details> - Bold and enlarge the dropdown summary text with more whitespace - Remove GitHub star counts from all library entries Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
<details>dropdown summarizing prior JAX RL repos evaluated (Tunix, Brax, RLax, PureJaxRL, Stoix, Rejax/EvoRL, Apple RLAX) and why none were a fit, based onrl_libariers.mdTest plan
🤖 Generated with Claude Code