The MixtureKit package provides high-level helper functions to merge pretrained and fine-tuned models
into a unified framework, integrating advanced Mixture-of-Experts (MoE) architectures.
- One‑line merge of multiple HF checkpoints into a single MoE model.
- Supports Branch‑Train‑MiX (BTX), Branch‑Train‑Stitch (BTS) and vanilla MoE.
- Built‑in routing visualizer: inspect which tokens each expert receives — overall (coarse‑grained) and per layer (fine‑grained). See examples/README_vis.md for details.
```bash
# Create a fresh conda environment (recommended)
conda create -n mixturekit python=3.12
conda activate mixturekit

# Clone & install in editable mode for development
git clone https://github.com/MBZUAI-Paris/MixtureKit
cd MixtureKit
pip install -e .
```

Requirements: Python ≥ 3.10 · PyTorch ≥ 2.5. The correct version of `transformers` is pulled automatically.
The script below builds a BTX MoE that routes tokens between a Gemma‑4B base model and two specialized fine-tuned experts (FrenchGemma for the French language and MedGemma for health information). For the BTS or vanilla architectures, set `moe_method` to `BTS` or `traditional`, respectively. For other model families, comment out the `model_cls` line.
```bash
# From the repo root
python examples/example_build_moe.py
```

What happens under the hood?
- A config dictionary is created that lists the base expert, two additional experts, the routing layers, etc.
- `MixtureKit.build_moe()` merges the checkpoints and writes the MoE to `models_merge/gemmax/`.
- The script reloads the model with `AutoModelForCausalLM` and prints a parameter‑breakdown table — only router weights stay trainable.
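For orientation, here is a minimal sketch of the shape of such a script. The config keys, argument passing, and checkpoint paths below are illustrative assumptions, not the verified MixtureKit API; the ready‑to‑use configs in examples/config_examples.txt are the authoritative reference.

```python
# Illustrative sketch only: the exact config schema of MixtureKit.build_moe()
# may differ; see examples/config_examples.txt for the real configs.
from transformers import AutoModelForCausalLM
import MixtureKit

config = {
    "base_model": "path/to/gemma-4b",                        # placeholder: Gemma-4B base checkpoint
    "experts": ["path/to/FrenchGemma", "path/to/MedGemma"],  # placeholder: fine-tuned expert checkpoints
    "moe_method": "BTX",                                     # or "BTS" / "traditional"
    "output_dir": "models_merge/gemmax",
}

MixtureKit.build_moe(config)   # merge the checkpoints into a single MoE on disk

# Reload the merged model; the example script additionally freezes everything
# except the routers and prints a parameter-breakdown table.
model = AutoModelForCausalLM.from_pretrained("models_merge/gemmax")
total = sum(p.numel() for p in model.parameters())
print(f"merged MoE has {total:,} parameters")
```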
🔧 Fine-tuning / Supervised Fine-Tuning (SFT)
The mixture_training/ folder contains a ready-to-go scaffold that trains any merged MoE checkpoint with LoRA adapters (BTX or BTS).
```text
mixture_training/
├── config_training.yaml   # all hyper-params in one place
├── deepspeed_config.yaml  # ZeRO-3 config
├── requirements.txt       # extra libs (trl, deepspeed, wandb, etc.)
└── train_model.py         # launch script
```
- Expected format: a 🤗 `datasets` Arrow table saved on disk and loaded with `load_from_disk()`. `config_training.yaml` assumes:
  - a column called `messages` (list of chat turns),
  - each turn is a dict `{"role": "...", "content": "..."}` (same schema as ShareGPT).
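A minimal sketch of preparing a dataset in this format (the path and example content are illustrative):

```python
# Build a tiny ShareGPT-style dataset with a "messages" column and save it to
# disk so the training script can load it with load_from_disk().
from datasets import Dataset, load_from_disk

rows = [
    {
        "messages": [
            {"role": "user", "content": "Bonjour, comment allez-vous ?"},
            {"role": "assistant", "content": "Très bien, merci !"},
        ]
    },
]

Dataset.from_list(rows).save_to_disk("data/my_sft_dataset")   # point dataset_path here

ds = load_from_disk("data/my_sft_dataset")
print(ds[0]["messages"][0])   # {'role': 'user', 'content': 'Bonjour, comment allez-vous ?'}
```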
Minimal edits for your own run:
| Key | What it does |
|---|---|
| `dataset_path` | Path to the dataset produced in step 2 |
| `model_id` | Path or HF-Hub id of the merged MoE (e.g. `models_merge/gemmax`) |
| `output_dir` | Where to write checkpoints / LoRA adapters |
| `run_name` | Friendly name shown in 🤗 wandb / logs |
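If you prefer to script these edits rather than change the file by hand, a small PyYAML snippet can rewrite the four keys. This assumes they sit at the top level of config_training.yaml, which you should verify against the shipped file; all values below are illustrative.

```python
# Update the run-specific keys in config_training.yaml (assumes top-level keys).
import yaml

cfg_path = "mixture_training/config_training.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg.update(
    dataset_path="data/my_sft_dataset",   # dataset saved in the previous step
    model_id="models_merge/gemmax",       # merged MoE from the build step
    output_dir="outputs/gemmax-sft",      # illustrative output location
    run_name="gemmax-sft-run1",           # illustrative run name
)

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```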
```bash
accelerate launch --config_file mixture_training/deepspeed_config.yaml mixture_training/train_model.py
```

The script will:
- Load the MoE checkpoint in bf16, with distributed training if multiple GPUs are available,
- Train with 🤗 `trl`'s `SFTTrainer`,
- Save incremental checkpoints to the local directory `output_dir`.
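Conceptually, the training loop amounts to something like the sketch below; the actual train_model.py, its argument names, and its LoRA settings may differ, so treat this only as an outline.

```python
# Rough outline of an SFT run over the merged MoE (illustrative, not the shipped script).
import torch
from datasets import load_from_disk
from peft import LoraConfig
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

# Load the merged MoE in bf16
model = AutoModelForCausalLM.from_pretrained(
    "models_merge/gemmax", torch_dtype=torch.bfloat16
)
dataset = load_from_disk("data/my_sft_dataset")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="outputs/gemmax-sft", run_name="gemmax-sft-run1"),
    peft_config=LoraConfig(r=16, lora_alpha=32),   # illustrative LoRA adapter settings
)
trainer.train()   # incremental checkpoints land in output_dir
```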
Tip: To switch from BTX/Traditional to BTS fine-tuning, open `config_training.yaml` and set `is_bts` to `True`.
The file `examples/config_examples.txt` contains more ready‑to‑use configs. Copy one into a small script and call `build_moe()`.
| Key | Scenario | MoE flavour |
|---|---|---|
| `llama3x` | Two Llama‑3‑1B experts | BTX |
| `qwen3x` | Three Qwen‑3‑0.6B experts | Traditional MoE |
| `gemmabts` | Gemma + 2 Gemma experts | BTS (layer‑stitching) |
- API reference — open `docs/index.html` or visit the online version.
Pull requests are welcome! Please open an issue first to discuss your ideas.
MixtureKit is released under the BSD 3-Clause License — see the LICENSE file for details.
MixtureKit: A General Framework for Composing, Training, and Visualizing Mixture-of-Experts Models
Happy mixing! 🎛️
