@@ -3,8 +3,12 @@
 # [WIP] Reproduction of [Deepscaler](https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2) with Single-turn Agentic framework.
 
 import contextlib
+import logging
+import math
 import os
+import sys
 
+from absl import logging as absl_logging
 from flax import nnx
 import grain
 import jax
@@ -14,11 +18,6 @@
 from orbax import checkpoint as ocp
 import qwix
 
-import math
-import logging
-import sys
-from absl import logging as absl_logging
-
 # ====== Logging Configuration ======
 # 1. Force absl to use python logging
 absl_logging.use_python_logging()
@@ -29,29 +28,31 @@
     level=logging.INFO,
     format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
-    force=True
+    force=True,
 )
 
 # 3. Explicitly set levels for relevant loggers
 logging.getLogger().setLevel(logging.INFO)
-logging.getLogger('absl').setLevel(logging.INFO)
+logging.getLogger("absl").setLevel(logging.INFO)
 
 # 4. Set absl verbosity
 absl_logging.set_verbosity(absl_logging.INFO)
-absl_logging.set_stderrthreshold('info')
+absl_logging.set_stderrthreshold("info")
 
 print("Logging configured at INFO level.")
 
 try:
   from etils import ecolab
+
   cm = ecolab.adhoc(
       source=ecolab.FROM_NOTEBOOK_OR_HEAD,
-      reload='tunix',
-      behavior='preferred',
+      reload="tunix",
+      behavior="preferred",
       cell_autoreload=True,
   )
 except:
   import contextlib
+
   cm = contextlib.nullcontext()
 
 with cm:
@@ -72,6 +73,7 @@
 
 try:
   import pathwaysutils
+
   pathwaysutils.initialize()
 except:
   pass
@@ -119,6 +121,7 @@
 # Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
 # stable updates.
 EPSILON = 0.2
+EPSILON_HIGH = 0.28
 
 # ====== Training ======
 ENABLE_REMAT = True
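The new `EPSILON_HIGH` makes the clipping window asymmetric, in the spirit of DAPO's "clip-higher" trick: the importance ratio may rise to 1 + 0.28 but fall only to 1 - 0.2, leaving low-probability tokens more room to gain mass. A minimal sketch of how such a two-sided clip is typically applied to the surrogate loss, assuming per-token ratios and advantages are already computed (names here are illustrative, not Tunix's API):

```python
import jax.numpy as jnp


def clipped_surrogate(ratio, advantage, eps_low=0.2, eps_high=0.28):
  """Per-token PPO/GRPO surrogate with an asymmetric clip window."""
  # Clip the importance ratio into [1 - eps_low, 1 + eps_high].
  clipped = jnp.clip(ratio, 1.0 - eps_low, 1.0 + eps_high)
  # Pessimistic objective: take the smaller of the two candidates per
  # token, then negate so it can be minimized as a loss.
  return -jnp.minimum(ratio * advantage, clipped * advantage)
```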
@@ -135,6 +138,11 @@
 # Number of training steps.
 MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
 
+# Max concurrency for parallel processing of trajectories.
+MAX_CONCURRENCY = 64
+
+MODEL_DTYPE = jnp.float32
+
 # === AdamW, warmup, cosine scheduler ===
 LEARNING_RATE = 1e-6
 B1 = 0.9  # Adam beta1
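Routing the dtype through a single `MODEL_DTYPE` constant (and switching the reference model off `jnp.bfloat16` in the hunks below) keeps actor and reference in the same precision, so their log-prob ratios are not skewed by rounding alone. A toy illustration of the rounding gap bf16 introduces, purely for intuition:

```python
import jax.numpy as jnp

logits = jnp.array([2.1, -1.3, 0.7], dtype=jnp.float32)
roundtrip = logits.astype(jnp.bfloat16).astype(jnp.float32)
# bf16 keeps only ~8 mantissa bits, so the round-trip is lossy and the
# printed gap is non-zero.
print(jnp.max(jnp.abs(logits - roundtrip)))
```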
@@ -190,9 +198,9 @@
   )
 
   rollout_mesh = jax.sharding.Mesh(
-    rollout_device_list,
-    axis_names=ROLLOUT_MESH[1],
-    axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
+      rollout_device_list,
+      axis_names=ROLLOUT_MESH[1],
+      axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
   )
   # rollout_mesh = jax.make_mesh(
   #     *ROLLOUT_MESH,
@@ -209,9 +217,9 @@
   #     axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
   # )
   trainer_mesh = jax.sharding.Mesh(
-    trainer_devices_list,
-    axis_names=TRAINER_MESH[1],
-    axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
+      trainer_devices_list,
+      axis_names=TRAINER_MESH[1],
+      axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
   )
 else:
   rollout_mesh = mesh
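For reference, `jax.sharding.Mesh` is given an explicit device array plus axis names and `Auto` axis types, which leaves the compiler free to pick shardings along every axis. A self-contained sketch of the same pattern; the axis names and the 2x2 reshape are made up, and it assumes at least four devices:

```python
import numpy as np
import jax

# Arrange the first four devices into a 2x2 grid.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = jax.sharding.Mesh(
    devices,
    axis_names=("fsdp", "tp"),
    axis_types=(jax.sharding.AxisType.Auto,) * 2,
)
print(mesh.shape)  # {'fsdp': 2, 'tp': 2}
```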
@@ -220,6 +228,7 @@
 # %%
 try:
   from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
+
   file_open = gfile.Open
 
   NOTEBOOK_ENV = "g3"
@@ -259,12 +268,17 @@
 AutoTokenizer = transformers.AutoTokenizer
 
 
-DEEPSCALER_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json")
-AIME_2024_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet")
+DEEPSCALER_DATA_PATH = os.path.join(
+    DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json"
+)
+AIME_2024_DATA_PATH = os.path.join(
+    DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet"
+)
+
 
 def create_datasets(
     train_ds_path: str = DEEPSCALER_DATA_PATH,
-    test_ds_path: str = AIME_2024_DATA_PATH
+    test_ds_path: str = AIME_2024_DATA_PATH,
 ):
   def preprocess_fn(example, index):
     return {
@@ -273,7 +287,9 @@ def preprocess_fn(example, index):
273287 "data_source" : "math" ,
274288 }
275289
276- with file_open (train_ds_path ) as train_f , file_open (test_ds_path , 'rb' ) as test_f :
290+ with file_open (train_ds_path ) as train_f , file_open (
291+ test_ds_path , "rb"
292+ ) as test_f :
277293 train_df = pd .read_json (train_f )
278294 test_df = pd .read_parquet (test_f )
279295
@@ -290,7 +306,9 @@ def process_item(item):
     prompt = f"{question} {instruction}"
     prompt = tokenizer.apply_chat_template(
         [{"role": "user", "content": prompt}],
-        tokenize=False, add_generation_prompt=True)
+        tokenize=False,
+        add_generation_prompt=True,
+    )
 
     return {
         "prompts": prompt,
@@ -302,6 +320,7 @@ def process_item(item):
   test_ds = grain.MapDataset.source(test_ds).map(process_item)
   return train_ds, test_ds
 
+
 # %%
 
 tokenizer_source = MODEL_PATH if NOTEBOOK_ENV == "g3" else MODEL_VERSION
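`grain.MapDataset.source` wraps any random-access sequence and `.map` applies its transform lazily per element. A minimal self-contained sketch of the same pipeline shape; the records and transform are made up:

```python
import grain

records = [
    {"question": "What is 1 + 1?", "answer": "2"},
    {"question": "What is 3 * 3?", "answer": "9"},
]

# The transform runs lazily; elements materialize on indexing.
ds = grain.MapDataset.source(records).map(
    lambda r: {"prompts": r["question"], "answer": r["answer"]}
)
print(ds[0])  # {'prompts': 'What is 1 + 1?', 'answer': '2'}
```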
@@ -339,7 +358,7 @@ def process_item(item):
 
 print("MODEL_PATH: ", MODEL_PATH)
 qwen2_ref = params_lib.create_model_from_safe_tensors(
-    MODEL_PATH, config, trainer_mesh, dtype=jnp.bfloat16
+    MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
 )
 
 
@@ -367,12 +386,13 @@ def get_lora_model(base_model, model_mesh):
 
   return lora_model
 
+
 # %%
 if TRAIN_WITH_LORA:
   qwen2_actor = get_lora_model(qwen2_ref, trainer_mesh)
 else:
   qwen2_actor = params_lib.create_model_from_safe_tensors(
-      MODEL_PATH, config, trainer_mesh, dtype=jnp.float32
+      MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
   )
 
 # %%
@@ -446,7 +466,7 @@ def get_lora_model(base_model, model_mesh):
446466 "rollout_sglang_jax_disable_radix_cache" : True ,
447467 "rollout_sglang_jax_enable_deterministic_sampling" : False ,
448468 "rollout_sglang_jax_chunked_prefill_size" : 2048 ,
449- "rollout_sglang_jax_max_running_requests" : 32 ,
469+ "rollout_sglang_jax_max_running_requests" : MAX_CONCURRENCY ,
450470 "rollout_sglang_jax_page_size" : 128 ,
451471}
452472
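Presumably the point of routing both knobs through `MAX_CONCURRENCY` is to keep the agent loop's `max_concurrency` (set in the GRPO config below) in lock-step with the sglang-jax server's `max_running_requests`, so the trajectory collector never submits more in-flight rollouts than the server is configured to run.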
@@ -509,8 +529,9 @@ def get_lora_model(base_model, model_mesh):
     max_response_length=MAX_RESPONSE_LENGTH,
     beta=BETA,
     epsilon=EPSILON,
+    epsilon_high=EPSILON_HIGH,
     system_prompt="",
-    max_concurrency=64,
+    max_concurrency=MAX_CONCURRENCY,
 )
 
 # %%