
Commit bc58ab8

pc0618 and dlwh authored
Add speedrun/profiling helpers for OLMoE + Mixtral (#2249)
Adds speedrun/profiling entrypoints used to reproduce OLMoE + Mixtral MFU experiments and to A/B MoE kernels (ragged-dot vs grouped-matmul / GMM) without touching core training code.

## What’s Included

### Core drivers (intended to be reused)

- `experiments/speedrun/custom_mixtral.py`
  - Custom Mixtral implementation + `CustomMixtralConfig` (registered as `custom_mixtral`, aliased as `MixtralConfig`).
  - Supports `use_gmm=True` (grouped-matmul MoE experts via `hnn.MoELinear`) vs `use_gmm=False` (ragged-dot MoE experts); see the sketch below for how a matched pair can be set up for A/B comparisons.
  - Uses `jax.experimental.shard_map` for shard-safe routing/permutation and wraps results with `hax.named(...)` so axis metadata stays consistent.
  - Exposes attention knobs (`use_flash_attention`, `attn_backend`, `flash_attention_block_size`) and defaults `flash_attention_block_size=None` to let splash choose block sizes.
  - HF interoperability via `HFCheckpointConverter` and `to_hf_config`/`from_hf_config`.
- `experiments/speedrun/olmoe_1b7b_nemotron_40b.py`
  - CLI launcher with two presets:
    - `olmoe_1b7b`: OLMoE-style MoE geometry initialized from `allenai/OLMoE-1B-7B-0125` (reference checkpoint + tokenizer).
    - `mixtral_8x7b`: Mixtral 8x7B geometry (aligned with MaxText’s model geometry).
  - `nemotron_only_speedrun(...)` variant skips the Paloma validation mixture to avoid rebuilding Paloma caches during MFU runs.
  - Optional flags to *append* selected XLA flags onto the baseline `LIBTPU_INIT_ARGS` from the cluster config for quick A/B:
    - `--append-ici-ag-pipelining-flags`
    - `--append-async-collective-permute-flag`
- `experiments/speedrun/olmoe_eval.py`
  - Lightweight eval runner for HF-exported checkpoints.
  - Takes `--model-path gs://.../hf/step-.../` and runs `CORE_TASKS_PLUS_MMLU` via the executor on a TPU slice.

### Short baseline runs / comparisons

- `experiments/speedrun/mixtral_8x7b_ragged_run.py`
- `experiments/speedrun/mixtral_8x7b_ragged_run_seq4096.py`
- Ragged-dot (non-GMM) Mixtral 8x7B baselines for quick sanity/perf checks.

### Debug/sweep scripts (primarily for development)

- `experiments/speedrun/pranshu_llama_75m_run.py`
  - Small Llama baseline speedrun for sanity checks.
- `experiments/speedrun/pranshu_mixtral_moe_*`
  - A set of experiments to compare ragged-dot vs grouped-matmul MoE, including a small debug run and small parameter sweeps.
  - These are intentionally explicit and verbose to make it easy to reproduce A/B comparisons.

### Launch wrappers

- `scripts/run_olmoe_v5p64_profile.sh`
- `scripts/run_maxtext_mixtral_profile.sh`
- Convenience wrappers for launching profiling/MFU runs through `marin.run.ray_run`.
- Designed to rely on env vars for secrets (`WANDB_API_KEY`, `HF_TOKEN`, etc.) rather than hardcoding them.

## Notes / Operational Guidance

- These drivers are designed to be launched via `uv run python -m marin.run.ray_run ... -- python <driver>.py ...`.
- The custom Mixtral implementation uses shard-safe routing/permutation (`shard_map`) to avoid shape/axis mismatches when running inside sharded transforms.
- Most runs set the auxiliary MoE losses (`lbl_coef` / `rzl_coef`) to `None` when doing pure performance profiling to avoid extra logging/overhead.

---------

Co-authored-by: David Hall <david.hall@openathena.ai>
1 parent 995f167 commit bc58ab8

18 files changed

Lines changed: 2209 additions & 2 deletions
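To make the ragged-dot vs grouped-matmul A/B concrete, here is a minimal sketch (not part of this commit) of how a matched config pair could be built from the custom config. The field names follow the ragged-run scripts in this diff; the specific values and variable names are illustrative assumptions.

```python
# Illustrative sketch only: a matched config pair for A/B-ing MoE kernels.
# Field names follow experiments/speedrun/mixtral_8x7b_ragged_run*.py below.
from experiments.speedrun.custom_mixtral import MixtralConfig

common = dict(
    seq_len=4096,
    hidden_dim=4096,
    intermediate_dim=14336,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    n_routed_experts=8,
    num_experts_per_tok=2,
    gradient_checkpointing=True,
    scan_layers=True,
    cross_entropy_block_size=32000,
    lbl_coef=None,  # auxiliary MoE losses off for pure perf profiling
    rzl_coef=None,
)

ragged_config = MixtralConfig(use_gmm=False, **common)  # ragged-dot experts
gmm_config = MixtralConfig(use_gmm=True, **common)      # grouped-matmul experts (hnn.MoELinear)
```

Either config can then be dropped into a `SpeedrunConfig` exactly as the baseline runs below do, keeping everything but the expert kernel identical between the two runs.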

experiments/speedrun/custom_mixtral.py

Lines changed: 699 additions & 0 deletions
Large diffs are not rendered by default.
experiments/speedrun/mixtral/scripts/run_maxtext_mixtral_profile.sh

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
#!/usr/bin/env bash
# Example run (from repo root, with tmp dirs on local SSD):
# TMPDIR=/tmp RAY_TMPDIR=/tmp ./experiments/speedrun/mixtral/scripts/run_maxtext_mixtral_profile.sh
# Run MaxText Mixtral profiling on v5p-64 via Marin's Ray launcher.

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../../.." && pwd)"
cd "$REPO_ROOT"

VENV_PATH="${REPO_ROOT}/maxtext_marin"
if [[ ! -d "${VENV_PATH}" || ! -f "${VENV_PATH}/bin/activate" ]]; then
  echo "Missing virtualenv at ${VENV_PATH}. Please run 'uv venv --python 3.12 maxtext_marin' first." >&2
  exit 1
fi

if [[ ! -d "submodules/maxtext" ]]; then
  echo "Expected MaxText checkout under submodules/maxtext. Please follow docs/tutorials/co-develop.md." >&2
  exit 1
fi

source "${VENV_PATH}/bin/activate"

uv pip install -e lib/marin >/dev/null
uv pip install -e lib/levanter >/dev/null

RUN_SUFFIX="$(date +%Y%m%d%H%M%S)"
RUN_NAME="maxtext_mixtral_profile_${RUN_SUFFIX}"
OUTPUT_GCS="gs://marin-us-central1/maxtext/profiles/${RUN_NAME}"

REMOTE_CMD=$(cat <<'EOF'
set -euo pipefail
cd submodules/maxtext
export RUN_PREFLIGHT=false
export REPO_ROOT="$(pwd)/.."
export PYTHONPATH="$(pwd)/src:${REPO_ROOT}/lib/marin/src:${REPO_ROOT}/lib/levanter/src:${REPO_ROOT}/experiments"
export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=81920 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
REQ_FILE=/tmp/maxtext_profile_requirements.txt
cat <<'REQ' > "${REQ_FILE}"
absl-py==2.3.1
aqtp==0.9.0
datasets==3.6.0
flax==0.10.0
gcsfs==2025.3.0
grain[parquet]==0.2.9
huggingface_hub==0.36.0
jsonlines==4.0.0
ml-collections==1.1.0
ml-goodput-measurement==0.0.15
omegaconf==2.3.0
optax==0.2.6
orbax-checkpoint==0.11.26
pathwaysutils==0.1.3
protobuf==4.23.4
pydantic==2.11.10
sentencepiece==0.2.1
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboard-plugin-profile==2.15.0
tensorboardx==2.6.4
tokenizers==0.22.1
tiktoken==0.12.0
transformers==4.57.1
cloud-tpu-diagnostics==0.1.5
werkzeug==3.1.3
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
REQ
python3 -m pip install --upgrade pip
python3 -m pip uninstall -y cloud-accelerator-diagnostics google-cloud-aiplatform google-cloud-bigquery google-cloud-resource-manager google-cloud-storage shapely || true
python3 -m pip install --force-reinstall \
  numpy==1.26.4 \
  ml_dtypes==0.5.0
python3 -m pip install --force-reinstall \
  "jax[tpu]==0.6.2" \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m pip install --requirement "${REQ_FILE}"
python3 - <<'PY'
import jax
print("JAX OK", jax.__version__, "PJRT backend:", jax.default_backend())
try:
    from jaxlib import xla_extension
except ImportError:
    from jax._src.lib import _jax as xla_extension
print("DistributedRuntimeClient?", hasattr(xla_extension, "DistributedRuntimeClient"))
PY
python3 -m MaxText.train src/MaxText/configs/base.yml \
  model_name=mixtral-8x7b \
  steps=40 \
  per_device_batch_size=32 \
  enable_checkpointing=false \
  remat_policy=full \
  ici_fsdp_parallelism=-1 \
  max_target_length=1024 \
  base_output_directory=${OUTPUT_PATH} \
  run_name=${RUN_NAME} \
  dataset_type=synthetic \
  reuse_example_batch=1 \
  gcs_metrics=true \
  profiler=xplane \
  skip_first_n_steps_for_profiler=10 \
  profiler_steps=10 \
  upload_all_profiler_results=False \
  attention=dot_product \
  enable_nnx=false \
  sa_block_q=1024 \
  sa_block_q_dkv=2048 \
  sa_block_q_dq=2048
EOF
)

UV_PROJECT_ENV="${VENV_PATH}" uv run python -m marin.run.ray_run \
  --cluster "infra/marin-us-central1.yaml" \
  --extra tpu \
  --env_vars WANDB_MODE online \
  --env_vars WANDB_API_KEY "${WANDB_API_KEY:-}" \
  --env_vars WANDB_ENTITY "${WANDB_ENTITY:-}" \
  --env_vars WANDB_PROJECT "${WANDB_PROJECT:-}" \
  --env_vars HF_TOKEN "${HF_TOKEN:-}" \
  --env_vars PIP_NO_CACHE_DIR 1 \
  --env_vars RAY_TMPDIR /tmp \
  --env_vars TMPDIR /tmp \
  --env_vars MAXTEXT_REPO_ROOT submodules/maxtext \
  --env_vars OUTPUT_PATH "${OUTPUT_GCS}" \
  --env_vars RUN_NAME "${RUN_NAME}" \
  -- \
  bash -lc "${REMOTE_CMD}"
experiments/speedrun/mixtral_8x7b_ragged_run.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# Copyright 2025 The Marin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# nodryrun
import logging

from experiments.pretraining_datasets import NEMOTRON_WEIGHTS, tokenize_nemotron
from experiments.llama import llama3_tokenizer
from experiments.speedrun.custom_mixtral import MixtralConfig
from experiments.simple_train_config import SimpleTrainConfig
from marin.execution.executor import executor_main
from marin.processing.tokenize import lm_mixture_data_config
from fray.cluster import ResourceConfig
from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun

logger = logging.getLogger("ray")

nemotron_cc_steps = tokenize_nemotron(tokenizer=llama3_tokenizer)
nemotron_cc_mixture = lm_mixture_data_config(
    components=nemotron_cc_steps,
    weights=NEMOTRON_WEIGHTS,
    permutation_type="linear",
)

# Full Mixtral 8x7B configuration using ragged-dot MoE kernels.
mixtral_8x7b_ragged = MixtralConfig(
    seq_len=512,
    hidden_dim=4096,
    intermediate_dim=14336,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    n_routed_experts=8,
    num_experts_per_tok=2,
    gradient_checkpointing=True,
    scan_layers=True,
    use_gmm=False,  # stick with ragged-dot experts
    cross_entropy_block_size=32000,
    lbl_coef=None,
    rzl_coef=None,
)

speedrun_config = SpeedrunConfig(
    author=Author(
        name="Marin Team",
        affiliation="Marin Project",
        url=None,
    ),
    description="Train a Mixtral 8x7B MoE model (ragged dot) for 20 steps on Nemotron-CC tokens.",
    model_config=mixtral_8x7b_ragged,
    train_config=SimpleTrainConfig(
        resources=ResourceConfig.with_tpu(tpu_type="v5p-64"),
        train_batch_size=32,
        num_train_steps=20,
        learning_rate=3e-4,
        weight_decay=0.1,
        steps_per_eval=20,
        steps_per_export=20,
    ),
    tokenized_dataset=nemotron_cc_mixture,
)

if __name__ == "__main__":
    logger.info("Launching Mixtral 8x7B ragged-dot speedrun.")
    logger.info(
        "Settings: batch_size=%s, seq_len=%s, steps=%s, cross_entropy_block_size=%s",
        speedrun_config.train_config.train_batch_size,
        mixtral_8x7b_ragged.seq_len,
        speedrun_config.train_config.num_train_steps,
        mixtral_8x7b_ragged.cross_entropy_block_size,
    )
    executor_main(steps=default_speedrun("mixtral_8x7b_ragged_speedrun_bs32_seq512_v5p-64", speedrun_config))
experiments/speedrun/mixtral_8x7b_ragged_run_seq4096.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright 2025 The Marin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# nodryrun
import logging

from experiments.llama import llama3_tokenizer
from experiments.pretraining_datasets import NEMOTRON_WEIGHTS, tokenize_nemotron
from experiments.simple_train_config import SimpleTrainConfig
from experiments.speedrun.custom_mixtral import MixtralConfig
from fray.cluster import ResourceConfig
from marin.execution.executor import executor_main
from marin.processing.tokenize import lm_mixture_data_config
from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun

logger = logging.getLogger("ray")

nemotron_cc_steps = tokenize_nemotron(tokenizer=llama3_tokenizer)
nemotron_cc_mixture = lm_mixture_data_config(
    components=nemotron_cc_steps,
    weights=NEMOTRON_WEIGHTS,
    permutation_type="linear",
)

# Full Mixtral 8x7B configuration using ragged-dot MoE kernels.
#
# This variant targets sequence length 4096 to match the MaxText Mixtral v5p benchmarks more closely.
mixtral_8x7b_ragged = MixtralConfig(
    seq_len=4096,
    hidden_dim=4096,
    intermediate_dim=14336,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    n_routed_experts=8,
    num_experts_per_tok=2,
    gradient_checkpointing=True,
    scan_layers=True,
    use_gmm=False,  # stick with ragged-dot experts
    cross_entropy_block_size=32000,
    lbl_coef=None,
    rzl_coef=None,
)

speedrun_config = SpeedrunConfig(
    author=Author(
        name="Marin Team",
        affiliation="Marin Project",
        url=None,
    ),
    description=("Train a Mixtral 8x7B MoE model (ragged dot) for 20 steps on Nemotron-CC tokens with seq_len=4096."),
    model_config=mixtral_8x7b_ragged,
    train_config=SimpleTrainConfig(
        resources=ResourceConfig.with_tpu(tpu_type="v5p-64"),
        train_batch_size=32,
        num_train_steps=20,
        learning_rate=3e-4,
        weight_decay=0.1,
        steps_per_eval=20,
        steps_per_export=20,
    ),
    tokenized_dataset=nemotron_cc_mixture,
)

if __name__ == "__main__":
    logger.info("Launching Mixtral 8x7B ragged-dot speedrun (seq_len=4096).")
    logger.info(
        "Settings: batch_size=%s, seq_len=%s, steps=%s, cross_entropy_block_size=%s",
        speedrun_config.train_config.train_batch_size,
        mixtral_8x7b_ragged.seq_len,
        speedrun_config.train_config.num_train_steps,
        mixtral_8x7b_ragged.cross_entropy_block_size,
    )
    executor_main(steps=default_speedrun("mixtral_8x7b_ragged_speedrun_bs32_seq4096_v5p-64", speedrun_config))
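Design note (an observation, not code from this commit): the two ragged-dot runs above differ only in `seq_len` and the run name. Assuming `MixtralConfig` is a plain dataclass, as draccus/Levanter-style configs typically are, further variants such as a grouped-matmul counterpart could be derived from the baseline without repeating the full field list:

```python
# Hypothetical sketch: derive variants from the baseline ragged-dot config.
# Assumes MixtralConfig is a dataclass; the names mirror the files above.
import dataclasses

from experiments.speedrun.mixtral_8x7b_ragged_run import mixtral_8x7b_ragged

# seq_len=4096 counterpart of the seq_len=512 baseline
mixtral_8x7b_ragged_seq4096 = dataclasses.replace(mixtral_8x7b_ragged, seq_len=4096)

# grouped-matmul (GMM) counterpart for ragged-dot vs GMM kernel comparisons
mixtral_8x7b_gmm = dataclasses.replace(mixtral_8x7b_ragged, use_gmm=True)
```

The runs in this commit instead spell each config out in full, which the commit message notes is intentional so that each A/B comparison stays self-contained and easy to reproduce.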
experiments/speedrun/olmoe/scripts/run_olmoe_v5p64_profile.sh

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# Example run (from repo root, keep tmp artifacts local):
# TMPDIR=/tmp RAY_TMPDIR=/tmp ./experiments/speedrun/olmoe/scripts/run_olmoe_v5p64_profile.sh
#
# Convenience wrapper for launching an OLMoE speedrun with profiling enabled.

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../../.." && pwd)"
cd "$REPO_ROOT"

: "${TMPDIR:=/tmp}"
: "${RAY_TMPDIR:=/tmp}"

uv run python -m marin.run.ray_run \
  --cluster "infra/marin-us-central1.yaml" \
  --extra tpu \
  --env_vars WANDB_MODE online \
  --env_vars WANDB_API_KEY "${WANDB_API_KEY:-}" \
  --env_vars WANDB_ENTITY "${WANDB_ENTITY:-}" \
  --env_vars WANDB_PROJECT "${WANDB_PROJECT:-}" \
  --env_vars HF_TOKEN "${HF_TOKEN:-}" \
  --env_vars PIP_NO_CACHE_DIR 1 \
  --env_vars RAY_TMPDIR "${RAY_TMPDIR}" \
  --env_vars TMPDIR "${TMPDIR}" \
  --env_vars JAX_COMPILATION_CACHE_DIR gs://marin-us-central1/jax-cache/olmoe_1b7b \
  --env_vars JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS 0 \
  --env_vars JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES -1 \
  --env_vars JAX_RAGGED_DOT_USE_RAGGED_DOT_INSTRUCTION 1 \
  -- \
  python experiments/speedrun/olmoe_1b7b_nemotron_40b.py \
    --model olmoe_1b7b \
    --dataset nemotron_cc \
    --tpu-type v5p-64 \
    --global-batch-size 512 \
    --num-train-steps 40 \
    --profile \
    --profile-start-step 15 \
    --profile-num-steps 20 \
    --force_run_failed true
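The `JAX_COMPILATION_CACHE_DIR` and `JAX_PERSISTENT_CACHE_*` variables above are JAX's standard persistent compilation cache knobs. For reference, a rough in-process equivalent, a sketch assuming one wanted to set the same options programmatically instead of through the launcher's env vars, is:

```python
# Sketch only: programmatic equivalent of the JAX cache env vars above.
import jax

jax.config.update("jax_compilation_cache_dir", "gs://marin-us-central1/jax-cache/olmoe_1b7b")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)   # cache every compile, no time threshold
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)   # no minimum entry size
```

`JAX_RAGGED_DOT_USE_RAGGED_DOT_INSTRUCTION` is passed through as an env var only; no corresponding `jax.config` flag is assumed here.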
