examples/deepscaler/train_deepscaler_nb.py (71 changes: 46 additions, 25 deletions)
@@ -3,8 +3,12 @@
# [WIP] Reproduction of [Deepscaler](https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2) with Single-turn Agentic framework.

import contextlib
import logging
import math
import os
import sys

from absl import logging as absl_logging
from flax import nnx
import grain
import jax
@@ -14,11 +18,6 @@
from orbax import checkpoint as ocp
import qwix

import math
import logging
import sys
from absl import logging as absl_logging

# ====== Logging Configuration ======
# 1. Force absl to use python logging
absl_logging.use_python_logging()
@@ -29,29 +28,31 @@
level=logging.INFO,
format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True
force=True,
)

# 3. Explicitly set levels for relevant loggers
logging.getLogger().setLevel(logging.INFO)
logging.getLogger('absl').setLevel(logging.INFO)
logging.getLogger("absl").setLevel(logging.INFO)

# 4. Set absl verbosity
absl_logging.set_verbosity(absl_logging.INFO)
absl_logging.set_stderrthreshold('info')
absl_logging.set_stderrthreshold("info")

print("Logging configured at INFO level.")

try:
from etils import ecolab

cm = ecolab.adhoc(
source=ecolab.FROM_NOTEBOOK_OR_HEAD,
reload='tunix',
behavior='preferred',
reload="tunix",
behavior="preferred",
cell_autoreload=True,
)
except Exception:
cm = contextlib.nullcontext()

with cm:
@@ -72,6 +73,7 @@

try:
import pathwaysutils

pathwaysutils.initialize()
except Exception:
pass
@@ -119,6 +121,7 @@
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2
EPSILON_HIGH = 0.28
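
Note: the two constants above configure asymmetric ("clip-higher") clipping, where the upper bound 1 + EPSILON_HIGH leaves more headroom for up-weighting tokens than 1 - EPSILON removes below. A minimal sketch of that surrogate, for illustration only (not Tunix's internal loss):

```python
# Illustrative sketch only, not Tunix's internal GRPO loss.
import jax.numpy as jnp

def clipped_objective(logp_new, logp_old, advantage):
  ratio = jnp.exp(logp_new - logp_old)
  clipped = jnp.clip(ratio, 1.0 - EPSILON, 1.0 + EPSILON_HIGH)
  # PPO-style pessimistic bound: take the smaller of the two surrogates.
  return jnp.minimum(ratio * advantage, clipped * advantage)
```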

# ====== Training ======
ENABLE_REMAT = True
@@ -135,6 +138,11 @@
# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
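
A worked instance of the formula above, with purely hypothetical values (none appear in this diff):

```python
# Hypothetical: 400 batches, 1 GRPO iteration, 0.9 train split, 1 epoch.
assert int(400 * 1 * 0.9 * 1) == 360  # total optimizer steps
```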

# Max concurrency for parallel processing of trajectories.
MAX_CONCURRENCY = 64

MODEL_DTYPE = jnp.float32

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 1e-6
B1 = 0.9 # Adam beta1
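# A minimal optax sketch of the scheduler/optimizer these constants
# presumably feed; the warmup ratio and b2 are assumptions, since the
# remaining constants are elided from this diff.

```python
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,           # 1e-6 above
    warmup_steps=int(0.1 * MAX_STEPS),  # assumed warmup ratio
    decay_steps=MAX_STEPS,
)
optimizer = optax.adamw(learning_rate=schedule, b1=B1, b2=0.999)
```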
@@ -190,9 +198,9 @@
)

rollout_mesh = jax.sharding.Mesh(
rollout_device_list,
axis_names = ROLLOUT_MESH[1],
axis_types = (jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
rollout_device_list,
axis_names=ROLLOUT_MESH[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
)
# rollout_mesh = jax.make_mesh(
# *ROLLOUT_MESH,
@@ -209,9 +217,9 @@
# axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
# )
trainer_mesh = jax.sharding.Mesh(
trainer_devices_list,
axis_names = TRAINER_MESH[1],
axis_types = (jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
trainer_devices_list,
axis_names=TRAINER_MESH[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
)
else:
rollout_mesh = mesh
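
For readers unfamiliar with explicit mesh construction: jax.sharding.Mesh takes a device array whose shape matches the axis names. A self-contained sketch under assumed values (ROLLOUT_MESH itself is defined outside this diff):

```python
# The (1, n_devices) shape and ("fsdp", "tp") names are assumptions.
import jax
import numpy as np

devices = np.asarray(jax.devices()).reshape(1, -1)
demo_mesh = jax.sharding.Mesh(devices, axis_names=("fsdp", "tp"))
```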
@@ -220,6 +228,7 @@
# %%
try:
from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile

file_open = gfile.Open

NOTEBOOK_ENV = "g3"
@@ -259,12 +268,17 @@
AutoTokenizer = transformers.AutoTokenizer


DEEPSCALER_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json")
AIME_2024_DATA_PATH = os.path.join(DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet")
DEEPSCALER_DATA_PATH = os.path.join(
DATA_PATH_PREFIX, "DeepScaleR-Preview-Dataset/deepscaler.json"
)
AIME_2024_DATA_PATH = os.path.join(
DATA_PATH_PREFIX, "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet"
)


def create_datasets(
train_ds_path: str = DEEPSCALER_DATA_PATH,
test_ds_path: str = AIME_2024_DATA_PATH
test_ds_path: str = AIME_2024_DATA_PATH,
):
def preprocess_fn(example, index):
return {
@@ -273,7 +287,9 @@ def preprocess_fn(example, index):
"data_source": "math",
}

with file_open(train_ds_path) as train_f, file_open(test_ds_path, 'rb') as test_f:
with file_open(train_ds_path) as train_f, file_open(
test_ds_path, "rb"
) as test_f:
train_df = pd.read_json(train_f)
test_df = pd.read_parquet(test_f)

@@ -290,7 +306,9 @@ def process_item(item):
prompt = f"{question} {instruction}"
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False, add_generation_prompt=True)
tokenize=False,
add_generation_prompt=True,
)

return {
"prompts": prompt,
@@ -302,6 +320,7 @@ def process_item(item):
test_ds = grain.MapDataset.source(test_ds).map(process_item)
return train_ds, test_ds
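
To see what apply_chat_template produces in process_item above, a standalone example (the checkpoint id is an assumption; any Qwen2-family instruct tokenizer behaves similarly):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "What is 1 + 1?"}],
    tokenize=False,
    add_generation_prompt=True,
)
# The result wraps the question in <|im_start|>/<|im_end|> chat markup and
# ends with the assistant header, so generation continues as the assistant.
print(prompt)
```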


# %%

tokenizer_source = MODEL_PATH if NOTEBOOK_ENV == "g3" else MODEL_VERSION
@@ -339,7 +358,7 @@ def process_item(item):

print("MODEL_PATH: ", MODEL_PATH)
qwen2_ref = params_lib.create_model_from_safe_tensors(
MODEL_PATH, config, trainer_mesh, dtype=jnp.bfloat16
MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
)
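
The dtype change above (bfloat16 to MODEL_DTYPE, i.e. float32) doubles the reference model's weight memory. Quick arithmetic for the 1.5B-parameter model this notebook targets:

```python
params = 1.5e9                          # DeepScaleR uses a 1.5B model
print(f"{params * 4 / 2**30:.1f} GiB")  # float32 weights: ~5.6 GiB
print(f"{params * 2 / 2**30:.1f} GiB")  # bfloat16 weights: ~2.8 GiB
```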


@@ -367,12 +386,13 @@ def get_lora_model(base_model, model_mesh):

return lora_model


# %%
if TRAIN_WITH_LORA:
qwen2_actor = get_lora_model(qwen2_ref, trainer_mesh)
else:
qwen2_actor = params_lib.create_model_from_safe_tensors(
MODEL_PATH, config, trainer_mesh, dtype=jnp.float32
MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
)
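
For reference, the low-rank update that get_lora_model presumably installs on the frozen base weights; a generic LoRA sketch, not qwix's API:

```python
# Generic LoRA sketch (not qwix's API): the frozen weight W gains a
# trainable low-rank delta scaled by alpha / rank.
import jax.numpy as jnp

def lora_forward(x, W, A, B, alpha, rank):
  return x @ W + (alpha / rank) * (x @ A @ B)  # A: (d, r), B: (r, k)
```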

# %%
@@ -446,7 +466,7 @@ def get_lora_model(base_model, model_mesh):
"rollout_sglang_jax_disable_radix_cache": True,
"rollout_sglang_jax_enable_deterministic_sampling": False,
"rollout_sglang_jax_chunked_prefill_size": 2048,
"rollout_sglang_jax_max_running_requests": 32,
"rollout_sglang_jax_max_running_requests": MAX_CONCURRENCY,
"rollout_sglang_jax_page_size": 128,
}

@@ -509,8 +529,9 @@ def get_lora_model(base_model, model_mesh):
max_response_length=MAX_RESPONSE_LENGTH,
beta=BETA,
epsilon=EPSILON,
epsilon_high=EPSILON_HIGH,
system_prompt="",
max_concurrency=64,
max_concurrency=MAX_CONCURRENCY,
)

# %%