Skip to content

Commit 02fa7b4

Browse files
tianshub and The tunix Authors
authored and committed
fix deepscaler notebook
PiperOrigin-RevId: 875752294
1 parent 791d90c commit 02fa7b4

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

examples/deepscaler/train_deepscaler_nb.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
# [WIP] Reproduction of [Deepscaler](https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2) with Single-turn Agentic framework.
44

55
import contextlib
6+
import logging
7+
import math
68
import os
9+
import sys
710

11+
from absl import logging as absl_logging
812
from flax import nnx
913
import grain
1014
import jax
@@ -14,11 +18,6 @@
1418
from orbax import checkpoint as ocp
1519
import qwix
1620

17-
import math
18-
import logging
19-
import sys
20-
from absl import logging as absl_logging
21-
2221
# ====== Logging Configuration ======
2322
# 1. Force absl to use python logging
2423
absl_logging.use_python_logging()
@@ -29,29 +28,31 @@
2928
level=logging.INFO,
3029
format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s",
3130
datefmt="%Y-%m-%d %H:%M:%S",
32-
force=True
31+
force=True,
3332
)
3433

3534
# 3. Explicitly set levels for relevant loggers
3635
logging.getLogger().setLevel(logging.INFO)
37-
logging.getLogger('absl').setLevel(logging.INFO)
36+
logging.getLogger("absl").setLevel(logging.INFO)
3837

3938
# 4. Set absl verbosity
4039
absl_logging.set_verbosity(absl_logging.INFO)
41-
absl_logging.set_stderrthreshold('info')
40+
absl_logging.set_stderrthreshold("info")
4241

4342
print("Logging configured at INFO level.")
4443

4544
try:
4645
from etils import ecolab
46+
4747
cm = ecolab.adhoc(
4848
source=ecolab.FROM_NOTEBOOK_OR_HEAD,
49-
reload='tunix',
50-
behavior='preferred',
49+
reload="tunix",
50+
behavior="preferred",
5151
cell_autoreload=True,
5252
)
5353
except:
5454
import contextlib
55+
5556
cm = contextlib.nullcontext()
5657

5758
with cm:
@@ -72,6 +73,7 @@
7273

7374
try:
7475
import pathwaysutils
76+
7577
pathwaysutils.initialize()
7678
except:
7779
pass
@@ -119,6 +121,7 @@
119121
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
120122
# stable updates.
121123
EPSILON = 0.2
124+
EPSILON_HIGH = 0.28
122125

123126
# ====== Training ======
124127
ENABLE_REMAT = True
@@ -135,6 +138,11 @@
135138
# Number of training steps.
136139
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
137140

141+
# Max concurrency for parallel processing of trajectories.
142+
MAX_CONCURRENCY = 64
143+
144+
MODEL_DTYPE = jnp.float32
145+
138146
# === AdamW, warmup, cosine scheduler ===
139147
LEARNING_RATE = 1e-6
140148
B1 = 0.9 # Adam beta1
@@ -190,9 +198,9 @@
190198
)
191199

192200
rollout_mesh = jax.sharding.Mesh(
193-
rollout_device_list,
194-
axis_names = ROLLOUT_MESH[1],
195-
axis_types = (jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
201+
rollout_device_list,
202+
axis_names=ROLLOUT_MESH[1],
203+
axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
196204
)
197205
# rollout_mesh = jax.make_mesh(
198206
# *ROLLOUT_MESH,
@@ -209,9 +217,9 @@
209217
# axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
210218
# )
211219
trainer_mesh = jax.sharding.Mesh(
212-
trainer_devices_list,
213-
axis_names = TRAINER_MESH[1],
214-
axis_types = (jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
220+
trainer_devices_list,
221+
axis_names=TRAINER_MESH[1],
222+
axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
215223
)
216224
else:
217225
rollout_mesh = mesh
@@ -220,6 +228,7 @@
220228
# %%
221229
try:
222230
from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
231+
223232
file_open = gfile.Open
224233

225234
NOTEBOOK_ENV = "g3"
@@ -259,12 +268,17 @@
259268
AutoTokenizer = transformers.AutoTokenizer
260269

261270

262-
DEEPSCALER_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json")
263-
AIME_2024_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet")
271+
DEEPSCALER_DATA_PATH = os.path.join(
272+
DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json"
273+
)
274+
AIME_2024_DATA_PATH = os.path.join(
275+
DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet"
276+
)
277+
264278

265279
def create_datasets(
266280
train_ds_path: str = DEEPSCALER_DATA_PATH,
267-
test_ds_path: str = AIME_2024_DATA_PATH
281+
test_ds_path: str = AIME_2024_DATA_PATH,
268282
):
269283
def preprocess_fn(example, index):
270284
return {
@@ -273,7 +287,9 @@ def preprocess_fn(example, index):
273287
"data_source": "math",
274288
}
275289

276-
with file_open(train_ds_path) as train_f, file_open(test_ds_path, 'rb') as test_f:
290+
with file_open(train_ds_path) as train_f, file_open(
291+
test_ds_path, "rb"
292+
) as test_f:
277293
train_df = pd.read_json(train_f)
278294
test_df = pd.read_parquet(test_f)
279295

@@ -290,7 +306,9 @@ def process_item(item):
290306
prompt = f"{question} {instruction}"
291307
prompt = tokenizer.apply_chat_template(
292308
[{"role": "user", "content": prompt}],
293-
tokenize=False, add_generation_prompt=True)
309+
tokenize=False,
310+
add_generation_prompt=True,
311+
)
294312

295313
return {
296314
"prompts": prompt,
@@ -302,6 +320,7 @@ def process_item(item):
302320
test_ds = grain.MapDataset.source(test_ds).map(process_item)
303321
return train_ds, test_ds
304322

323+
305324
# %%
306325

307326
tokenizer_source = MODEL_PATH if NOTEBOOK_ENV == "g3" else MODEL_VERSION
@@ -339,7 +358,7 @@ def process_item(item):
339358

340359
print("MODEL_PATH: ", MODEL_PATH)
341360
qwen2_ref = params_lib.create_model_from_safe_tensors(
342-
MODEL_PATH, config, trainer_mesh, dtype=jnp.bfloat16
361+
MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
343362
)
344363

345364

@@ -367,12 +386,13 @@ def get_lora_model(base_model, model_mesh):
367386

368387
return lora_model
369388

389+
370390
# %%
371391
if TRAIN_WITH_LORA:
372392
qwen2_actor = get_lora_model(qwen2_ref, trainer_mesh)
373393
else:
374394
qwen2_actor = params_lib.create_model_from_safe_tensors(
375-
MODEL_PATH, config, trainer_mesh, dtype=jnp.float32
395+
MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
376396
)
377397

378398
# %%
@@ -446,7 +466,7 @@ def get_lora_model(base_model, model_mesh):
446466
"rollout_sglang_jax_disable_radix_cache": True,
447467
"rollout_sglang_jax_enable_deterministic_sampling": False,
448468
"rollout_sglang_jax_chunked_prefill_size": 2048,
449-
"rollout_sglang_jax_max_running_requests": 32,
469+
"rollout_sglang_jax_max_running_requests": MAX_CONCURRENCY,
450470
"rollout_sglang_jax_page_size": 128,
451471
}
452472

@@ -509,8 +529,9 @@ def get_lora_model(base_model, model_mesh):
509529
max_response_length=MAX_RESPONSE_LENGTH,
510530
beta=BETA,
511531
epsilon=EPSILON,
532+
epsilon_high=EPSILON_HIGH,
512533
system_prompt="",
513-
max_concurrency=64,
534+
max_concurrency=MAX_CONCURRENCY,
514535
)
515536

516537
# %%

0 commit comments

Comments (0)