Skip to content

Commit de3704f

Browse files
Support inflight weight updates in RL (#2325)
Adds support for inflight weight updates in RL workers. --------- Co-authored-by: Christopher Chou <49086305+BabyChouSr@users.noreply.github.com>
1 parent 96c1669 commit de3704f

44 files changed

Lines changed: 3857 additions & 815 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

experiments/exp1743_rl_math.py

Lines changed: 0 additions & 239 deletions
This file was deleted.

experiments/exp2039_rl_math500.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2025 The Marin Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# nodryrun because vLLM is not installed by default
15+
16+
import datetime
17+
import logging
18+
import os
19+
20+
from levanter.models.llama import LlamaConfig
21+
from marin.execution.executor import executor_main
22+
from marin.rl.curriculum import CurriculumConfig, LessonConfig, SamplingParams
23+
from marin.rl.environments import EnvConfig
24+
from marin.rl.rl_losses import RLOOLoss
25+
26+
from marin.rl.rl_experiment_utils import (
27+
ModelConfig,
28+
RLExperimentConfig,
29+
make_rl_step,
30+
)
31+
32+
logger = logging.getLogger(__name__)
33+
34+
35+
# Base-model configuration for Llama 3.1 8B Instruct. The tokenizer and the
# initial checkpoint both point at the same Hugging Face repo as the model
# name; `config_class` selects the Levanter architecture used for training.
llama_3_1_8b = ModelConfig(
    name="meta-llama/Llama-3.1-8B-Instruct",
    type="llama",
    tokenizer="meta-llama/Llama-3.1-8B-Instruct",
    checkpoint="meta-llama/Llama-3.1-8B-Instruct",
    config_class=LlamaConfig,
)
42+
43+
44+
def create_math_curriculum(run_id: str, experiment_config: RLExperimentConfig) -> CurriculumConfig:
    """Build the single-lesson math curriculum for this run.

    Sampling parameters are taken from *experiment_config*; the curriculum
    contains one lesson ("math_full") backed by MathEnv and evaluates the
    full eval set (500 examples, for math500) after every step.
    """
    sampling = SamplingParams(
        temperature=1.0,
        # Overdo it since we know there are some with no signal?
        n_prompts=experiment_config.n_prompts,
        n_generations_per_prompt=experiment_config.n_generations_per_prompt,
        max_output_tokens=experiment_config.max_output_tokens,
        top_k=4096,
        stop_tokens=None,
    )

    math_lesson = LessonConfig(
        lesson_id="math_full",
        env_config=EnvConfig(
            env_class="marin.rl.environments.math_env.MathEnv",
            env_args={"seed": 42},
        ),
        dependencies=[],
        sampling_params=sampling,
    )

    # Sequence length must accommodate the prompt plus the generation budget.
    total_seq_len = experiment_config.max_input_tokens + experiment_config.max_output_tokens

    return CurriculumConfig(
        lessons={"math_full": math_lesson},
        eval_frequency=1,  # Run full eval after every step
        micro_eval_frequency=9999999,  # Effectively disable micro-eval
        actor_name=f"curriculum-{run_id}",
        eval_n_examples=500,  # for math500
        max_seq_len=total_seq_len,
    )
76+
77+
78+
def _experiment_name(experiment_config: RLExperimentConfig, datestamp: str) -> str:
    """Build a unique run name: shortened model name + config suffix + timestamp."""
    # e.g. "meta-llama/Llama-3.1-8B-Instruct" -> "llama-3.1-8bi"
    model_base_name = experiment_config.model_config.name.split("/")[-1].lower()
    model_base_name = model_base_name.replace("-instruct", "i")
    return f"{model_base_name}-{experiment_config.experiment_name_suffix}-{datestamp}"


def main():
    """Launch the async RL math training experiments via the Marin executor.

    Skips entirely on CI, since the experiment needs Hugging Face access
    (gated Llama weights).
    """
    if os.getenv("CI") is not None:
        logger.info("Skipping experiment execution on CI environment, needs HF access.")
        return

    llama_8b = RLExperimentConfig(
        model_config=llama_3_1_8b,
        rl_loss=RLOOLoss(
            kl_coef=0.0,
            clip_epsilon_low=0.2,
            clip_epsilon_high=0.28,
            synchronous=True,
            do_trainer_inference_mismatch_importance_sampling=True,
            tis_importance_sampling_ratio_max=2.0,
            do_overlong_filtering=True,
            vocab_tile_size=32064,
        ),
        experiment_name_suffix="math-lr=2e-6-bs=1024",
        train_batch_size=1024,
        per_device_parallelism=16,
        learning_rate=2e-6,
        max_input_tokens=1024,
        max_output_tokens=1024,
        n_prompts=64,
        n_generations_per_prompt=16,
        inflight_weight_updates=True,
        max_rollout_step_delay=1,
    )

    experiment_configs = [llama_8b]

    # Always include timestamp to avoid cache collisions between runs.
    datestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    experiments = []
    for experiment_config in experiment_configs:
        name = _experiment_name(experiment_config, datestamp)
        curriculum = create_math_curriculum(name, experiment_config)
        experiments.append(
            make_rl_step(
                name=name,
                config=experiment_config,
                curriculum=curriculum,
            ),
        )

    executor_main(
        steps=experiments,
        description="Async RL math training experiments",
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)