diff --git a/_includes/head-custom.html b/_includes/head-custom.html new file mode 100644 index 0000000..2300e4b --- /dev/null +++ b/_includes/head-custom.html @@ -0,0 +1,29 @@ + diff --git a/_layouts/default.html b/_layouts/default.html index c0ab71d..58e89d3 100644 --- a/_layouts/default.html +++ b/_layouts/default.html @@ -12,7 +12,7 @@ {% include head-custom.html %} -
+Marin Sync RL vs. Tinker on MATH-500. Both converge to ~0.43, but Tinker gets there in half the steps. (WandB report)
+Still, this was a major milestone: reproducible RL training on TPU, confirmed across 3 independent runs. + + + +Tinker (LoRA) vs. Marin (Full FT) on MATH-500. Left: both converge to ~0.43 test accuracy, but Tinker crosses 0.40 at step 29 vs. Marin at step 81 (dashed vertical lines). Center: Marin's format accuracy starts at 0.47 and takes ~80 steps to reach 0.80 (dashed lines), suggesting full fine-tuning disrupted format-following. Right: entropy is similar between both runs, ruling out exploration differences as the cause. (WandB report)
## Async RL: speeding up RL by decoupling training and inference @@ -40,14 +61,14 @@ Still, this was a major milestone: deterministic, reproducible RL training on TPSync RL runs each stage sequentially; Async RL runs the trainer (Levanter) and actor (vLLM) concurrently with weights synced via Arrow Flight.
-Synchronous RL is simple but slow---each stage (generate, train, eval) waits for the previous one. -In December, we built an asynchronous pipeline where the trainer (Levanter) and actor (vLLM) run concurrently, with model weights synchronized via [Arrow Flight](https://arrow.apache.org/docs/format/Flight.html). +Synchronous RL was a good simple first step. Unfortunately, it is quite slow as each stage of (generate, train, eval) completes sequentially which seriously hinders throughput. We know from prior work it is possible to build a performant Async RL system [[7]](#ref7) so that was our next goal. -The transition required solving several infrastructure challenges: +In December, we built an asynchronous pipeline where the trainer (Levanter) and actor (vLLM) run concurrently, with model weights synchronized via [Arrow Flight](https://arrow.apache.org/docs/format/Flight.html). This transition required solving several infrastructure challenges: + +- **Weight sync**: On-policy RL typically assumes the actor is sampling with same weights as the trainer. For async RL the goal is then to push updated weights to rollout workers frequently so the generations remain as 'on-policy' as possible. At LLM scale this is hard because each sync moves tens of GB, and if it is too slow you either stall rollout generation waiting for fresh weights or keep sampling from stale policies. + We were able to improve this by adding bfloat16 conversion ([PR #2388](https://github.com/marin-community/marin/pull/2388)), cutting transfer bandwidth in half from 32GB to 16GB and transfer time from 29s to 14s. +- **In-flight updates**: In an async setup, the trainer wants to publish new weights frequently, but if the actor pauses to wait for every update then inference still becomes part of the critical path. That creates a bad tradeoff between stale policies and idle inference time. We fixed this by adding background weight-sync threads, so rollout workers wait only for the first weights and then continue sampling while newer weights are transferred and hot-reloaded in the background ([PR #2325](https://github.com/marin-community/marin/pull/2325)). -- **Weight sync**: Arrow Flight transfers model weights from trainer to actor. We added bfloat16 conversion ([PR #2388](https://github.com/marin-community/marin/pull/2388)), cutting transfer bandwidth from 32GB to 16GB and transfer time from 29s to 14s. -- **In-flight updates**: Background weight sync threads prevent blocking rollouts ([PR #2325](https://github.com/marin-community/marin/pull/2325)). -- **Resource contention**: Coordinator actors deadlocked on CPU allocation---fixed by setting `num_cpus=0` ([PR #2350](https://github.com/marin-community/marin/pull/2350)). The result: async RL matched sync RL quality (0.26 to 0.50 on MATH-500 in 10 steps) with a **1.21x speedup**: @@ -61,19 +82,19 @@ The result: async RL matched sync RL quality (0.26 to 0.50 on MATH-500 in 10 ste ## Tracking down a mysterious divergence -But something was wrong. Two identical async RL runs diverged wildly after dozens of steps---one peaked at 0.514 accuracy, the other at only 0.482 before collapsing to 0.36. Training metrics (loss, KL, rewards) were consistent between runs. The divergence appeared only at inference time. ([WandB report](https://wandb.ai/marin-community/marin_post_training/reports/Async-RL-with-in-flight-updates-is-nondeterministic-with-vastly-different-test-results-and-policy-behavior-across-runs--VmlldzoxNTQzMzg5NA)) +Despite our initial success, we noticed something was wrong as training progressed. Two identical async RL runs diverged wildly after dozens of steps---one peaked at 0.514 accuracy, the other at only 0.482 before collapsing to 0.36. Training metrics (loss, KL, rewards) were consistent between runs. The divergence appeared only at inference time. ([WandB report](https://wandb.ai/marin-community/marin_post_training/reports/Async-RL-with-in-flight-updates-is-nondeterministic-with-vastly-different-test-results-and-policy-behavior-across-runs--VmlldzoxNTQzMzg5NA))  -2 async RL runs with identical configs diverged wildly after a few steps
+Two identical async RL runs diverge on eval accuracy (left, red shaded region) while train accuracy remains indistinguishable (right). Bottom row shows EMA smoothing (α=0.7) to make the divergence clearer. The bug only affected sampling at inference time. (WandB report)
-We systematically debugged the discrepancy ([#2260](https://github.com/marin-community/marin/pull/2260)): +We systematically debugged the discrepancy, investigating the following culprits ([#2260](https://github.com/marin-community/marin/pull/2260)): 1. **Token limit?** Truncating outputs to match Tinker's `max_tokens=512` still left accuracy far above Tinker's. Not the cause. 2. **Temperature?** Running Tinker with `temp=0.0` instead of `1.0` jumped accuracy from 0.294 to 0.442---a 51% improvement. Key clue. -3. **TPU vs GPU?** Running vLLM with `temp=0` and `temp=1` on both platforms revealed the root cause: on GPU, accuracy dropped from 42.1% to 28.3% as expected. On TPU: 40.9% vs 41.7%---**no difference at all**. +3. **TPU vs GPU?** Running vLLM with `temp=0` and `temp=1` on both platforms revealed the root cause: on GPU, accuracy dropped from 42.1% to 28.3% as expected, but on TPU: 40.9% vs 41.7%---**no difference at all**. -**Diagnosis: vLLM on TPU was silently ignoring temperature.** All prior async RL evaluations had been effectively greedy. +After a tireless search, we found the root cause **vLLM on TPU was silently ignoring temperature.** All prior async RL evaluations had been effectively greedy! We traced the bug to `input_batch.py` in the [tpu-inference](https://github.com/vllm-project/vllm/tree/main/vllm) codebase: @@ -83,31 +104,37 @@ if top_k <= 0 or top_k >= vocab_size: top_k = 1 # BUG: forces greedy! ``` -vLLM's docs specify that `top_k=-1` means "consider all tokens," but this code converted `-1` to `1`, selecting only the highest-probability token regardless of temperature. We filed a bug report ([tpu-inference #1386](https://github.com/vllm-project/tpu-inference/issues/1386)) which suggested a fix that got merged. +vLLM's docs specify that `top_k=-1` means "consider all tokens," but this code converted `-1` to `1`, selecting only the highest-probability token regardless of temperature. We filed a bug report ([tpu-inference #1386](https://github.com/vllm-project/tpu-inference/issues/1386)) and proposed a fix that got merged. + +This bug explained the nondeterminism: under effective greedy sampling, tiny floating-point differences in logit ordering can break ties differently across runs, and these small deviations compound over dozens of RL steps. Separately, we also caught a [loss normalization regression](https://github.com/marin-community/marin/pull/2039#issuecomment-3764238643) where switching from global token normalization to per-example normalization in the DAPO loss caused a 13% accuracy drop---short responses were being overweighted relative to detailed reasoning chains. -This single bug explained the nondeterminism: tiny sampling differences under greedy sampling, amplified over dozens of RL steps, caused the runs to diverge. +After fixing both issues, MATH-500 performance converged to a stable 0.46 (+/-0.02) over 186 steps ([WandB run](https://wandb.ai/marin-community/marin_post_training/runs/llama-3.1-8bi-math-lr=2e-6-bs=1024-20260117-110441-rollout-0)): + + +After fixing the vLLM top-k bug and loss normalization regression, MATH-500 Pass@1 reaches 0.46 within 10 steps and remains stable (mean=0.45, ±2σ=0.028) over 186 steps of training.
## Expanding to new models and benchmarks ### Qwen 2.5 support -Supporting Qwen 2.5 in the RL pipeline ([PR #2446](https://github.com/marin-community/marin/pull/2446), [PR #2456](https://github.com/marin-community/marin/pull/2456), [PR #2458](https://github.com/marin-community/marin/pull/2458)) required solving three issues: +Qwen 2.5 is very popular in the community for post-training [[11]](#ref11), so we began an effort to integrate it. Supporting Qwen 2.5 in the RL pipeline ([PR #2446](https://github.com/marin-community/marin/pull/2446), [PR #2456](https://github.com/marin-community/marin/pull/2456), [PR #2458](https://github.com/marin-community/marin/pull/2458)) required solving three issues: the model wasn't registered in tpu-inference (forcing a slow PyTorch fallback), the weight sync crashed due to different `q_proj` reshape logic, and Qwen's padded vocabulary (152064 tokens for hardware alignment) conflicted with Levanter's automatic vocab resizing. +After resolving these issues, we had both a working RL pipeline and support for Qwen 2.5, which we believed was a stronger base model for AIME-style math as validated by prior work [[11]](#ref11). That let us move on to a more ambitious task: AIME. + ### AIME25: harder math -MATH-500 is becoming saturated for modern models, so we moved to AIME25---the benchmark used by OLMo 3, GLM 4.7, and DeepSeek. -But AIME has only 30 questions, making evaluation extremely noisy: a single question difference shifts Pass@1 by 3%. +MATH-500 was a good initial test bed to validate our pipeline, but it is becoming saturated for modern models. We thus moved to AIME---the benchmark used by OLMo 3, GLM 4.7, and DeepSeek. -We addressed this by implementing a robust combinatorial Pass@k estimator (following the approach from Codex, lighteval, and DeepMath) and increasing the eval sample size K per task to 32 ([PR #2493](https://github.com/marin-community/marin/pull/2493)). +Unfortunately, the AIME has only 30 questions, making evaluation extremely noisy: a single question difference shifts Pass@1 by 3%. We addressed this by implementing a robust combinatorial Pass@k estimator (following the approach from Codex [[12]](#ref12), [lighteval](https://github.com/huggingface/lighteval), and DeepMath [[13]](#ref13)) and increasing the eval sample size K per task to 32 ([PR #2493](https://github.com/marin-community/marin/pull/2493)). Training Qwen 2.5 7B on [DeepMath-103K](https://huggingface.co/datasets/PRIME-RL/DeepMath-103K) showed steady Pass@16 gains (reaching 0.40) but Pass@1 remained near zero after 40 steps ([PR #2441](https://github.com/marin-community/marin/pull/2441)). -This suggests longer training may help---the model is learning but hasn't yet concentrated probability mass on correct solutions. +Our hypothesis is that longer training may help---Pass@1 is noisier and harder than Pass@16, so there may be a threshold the model needs to surpass in Pass@16 before we see stable improvement in Pass@1. - +AIME25 training: Pass@16 steadily improves to 0.40, but Pass@1 remains far from the 0.175 target due to high evaluation variance.
@@ -115,14 +142,12 @@ This suggests longer training may help---the model is learning but hasn't yet co While math is an excellent playground for testing RL, code is the domain that brings the most productivity to the real world and also brings much more complexity from the real world to our RL environments. One important lesson learned while expanding Marin RL to code is to always double-check the verifier logic. Our initial accuracy was falsely ~100% because the evaluation environment executed test scripts without invoking the validation function. - - -After fixing the eval, we reproduced Code-R1's results by training Qwen 2.5 7B Instruct with RL on 2K LeetCode questions ([PR #2286](https://github.com/marin-community/marin/pull/2286)). -HumanEval+ improved from 0.80 to 0.84 in 264 steps, closely matching Code-R1's reported 0.848 ([wandb run](https://wandb.ai/marin-community/marin_post_training/runs/qwen2.5-7bi-1m-code-r1-lr=5e-7-20260112-231710-rollout-0)). One remaining issue is we observe pass@1 started to destablize after 240 steps. This is likely because we omitted the kl divergence used in Code-R1. +After fixing the eval, we reproduced Code-R1's results [[10]](#ref10) by training Qwen 2.5 7B Instruct with RL on 2K LeetCode questions ([PR #2286](https://github.com/marin-community/marin/pull/2286)). +HumanEval+ improved from 0.80 to 0.84 in 264 steps, closely matching Code-R1's reported 0.848 ([wandb run](https://wandb.ai/marin-community/marin_post_training/runs/qwen2.5-7bi-1m-code-r1-lr=5e-7-20260112-231710-rollout-0)). One remaining issue is we observe pass@1 started to destabilize after 240 steps. This is likely because we omitted the KL divergence used in Code-R1 [[10]](#ref10). - + -Code-R1 reproduction results. Marin closely matches the original paper's improvements on HumanEval+ after 264 steps.
+Left: bugged verifier falsely showed ~100% accuracy. Right: after fixing the eval, HumanEval+ Pass@1 improves from 0.80 to 0.84, closely matching Code-R1's reported 0.848 (dashed line). Pass@1 destabilizes after ~240 steps.
## What's next @@ -134,11 +159,11 @@ For now, we are shifting our focus from RL to SFT in preparation for the next Ma ## Five lessons from building an RL pipeline from scratch -1. **Establish baselines first.** Tinker baselines saved weeks of debugging by telling us what to expect. -2. **Base model choice > algorithm tuning.** Llama 1B failed GSM8K while Llama 8B solved it in 1 step. -3. **Reproduce before innovating.** Code-R1 and MATH baselines caught subtle environment and prompt bugs. -4. **Evaluation needs care.** AIME25's 30 questions require multi-seed evals and careful estimators. -5. **Infrastructure is half the battle.** Memory, logging, weight sync, and dependencies required constant attention. +1. **Establish baselines first.** As much as we would have loved to zero-shot an Async RL pipeline, we are grateful for Tinker as the baselines saved weeks of debugging by helping us sanity check our Sync RL pipeline. +2. **Base model choice > algorithm tuning.** Llama 1B failed GSM8K while Llama 8B solved it in 1 step. This highlights the importance of mid-training / pre-training which we will further explore. +3. **Verify your environments end-to-end.** Code-R1 and MATH baselines caught subtle verifier and prompt bugs that would have silently corrupted results. For example, our initial code evaluator reported ~100% accuracy because it ran test scripts without invoking the validation function. +4. **Evaluation needs care.** AIME25's 30 questions make evaluation extremely noisy---a single question shifts Pass@1 by 3%. We initially estimated Pass@k by naively subsampling k trials from a pool of 16, which has high variance for small k. Switching to a combinatorial estimator ([PR #2493](https://github.com/marin-community/marin/pull/2493)) that uses all 16 trials gave us much more stable measurements. We also found that different prompt formatters for MATH-500 can significantly affect results, so consistency matters. +5. **Infrastructure is most of the battle.** The most performant RL algorithms today are quite simple iterations of the policy gradient. On the other hand, managing memory, logging, weight sync, and dependencies correctly and efficiently is a non-trivial systems challenge. ## Acknowledgements @@ -148,3 +173,44 @@ Thanks to Russell for laying out the infra for the RL effort and reviewing the C Thanks to Romain for driving forward the RL PR and suggesting to break it down into more manageable sub-PRs. Thanks to David for reviewing the bf16 weight transfer optimization PR and for giving numerous pointers and guidance during weekly meetings. And thanks to Percy for helping set goals and milestones for RL to keep us on course. + +## Cited Works + + +[1] Sutton, R.S. and Barto, A.G. (2018). [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html), 2nd Edition. MIT Press. + + +[2] Silver, D., Huang, A., Maddison, C. et al. (2016). [Mastering the game of Go with deep neural networks and tree search](https://www.nature.com/articles/nature16961). Nature, 529(7587), 484-489. + + +[3] Bai, Y., Kadavath, S., Kundu, S. et al. (2022). [Constitutional AI: Harmlessness from AI Feedback](https://arxiv.org/abs/2212.08073). arXiv:2212.08073. + + +[4] Ouyang, L., Wu, J., Jiang, X. et al. (2022). [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155). NeurIPS 2022. arXiv:2203.02155. + + +[5] OpenAI (2024). [OpenAI o1 System Card](https://arxiv.org/abs/2412.16720). arXiv:2412.16720. + + +[6] DeepSeek-AI et al. (2025). [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://arxiv.org/abs/2501.12948). arXiv:2501.12948. + + +[7] Mistral AI et al. (2025). [Magistral](https://arxiv.org/abs/2506.10910). arXiv:2506.10910. + + +[8] Thinking Machines Lab (2025). [LoRA Without Regret](https://thinkingmachines.ai/blog/lora/). + + +[9] Zheng, C. et al. (2025). [Defeating the Training-Inference Mismatch via FP16](https://arxiv.org/abs/2510.26788). arXiv:2510.26788. + + +[10] Liu, J. et al. (2025). [Code-R1: Reproducing R1 for Code with Reliable Rewards](https://github.com/ganler/code-r1). + + +[11] Liu, Z., Chen, Z., Li, J. et al. (2025). [Understanding R1-Zero-Like Training: A Critical Perspective (Dr. GRPO)](https://arxiv.org/abs/2503.20783). COLM 2025. arXiv:2503.20783. + + +[12] Chen, M., Tworek, J., Jun, H. et al. (2021). [Evaluating Large Language Models Trained on Code](https://arxiv.org/abs/2107.03374). arXiv:2107.03374. + + +[13] He, Z. et al. (2025). [DeepMath-103K: A Large-Scale, Challenging, Decontaminated, and Verifiable Mathematical Dataset for Advancing Reasoning](https://arxiv.org/abs/2504.11456). arXiv:2504.11456. diff --git a/assets/images/posts/async-rl-from-scratch/async_divergence.png b/assets/images/posts/async-rl-from-scratch/async_divergence.png index 3d1ca53..3a8b6b0 100644 Binary files a/assets/images/posts/async-rl-from-scratch/async_divergence.png and b/assets/images/posts/async-rl-from-scratch/async_divergence.png differ diff --git a/assets/images/posts/async-rl-from-scratch/code_r1_combined.png b/assets/images/posts/async-rl-from-scratch/code_r1_combined.png new file mode 100644 index 0000000..a43fc2e Binary files /dev/null and b/assets/images/posts/async-rl-from-scratch/code_r1_combined.png differ diff --git a/assets/images/posts/async-rl-from-scratch/deepmath_103k.png b/assets/images/posts/async-rl-from-scratch/deepmath_103k.png new file mode 100644 index 0000000..2676211 Binary files /dev/null and b/assets/images/posts/async-rl-from-scratch/deepmath_103k.png differ diff --git a/assets/images/posts/async-rl-from-scratch/postfix_stability.png b/assets/images/posts/async-rl-from-scratch/postfix_stability.png new file mode 100644 index 0000000..ba2c549 Binary files /dev/null and b/assets/images/posts/async-rl-from-scratch/postfix_stability.png differ diff --git a/assets/images/posts/async-rl-from-scratch/tinker_comparison.png b/assets/images/posts/async-rl-from-scratch/tinker_comparison.png new file mode 100644 index 0000000..0d3f53b Binary files /dev/null and b/assets/images/posts/async-rl-from-scratch/tinker_comparison.png differ