High-performance optimization suite for the Mixture-of-Experts (MoE) layers of LLaDA (Large Language Diffusion with mAsking). This implementation shifts inference from a CPU-bound regime to a compute-bound regime that keeps the GPU saturated.
All experiments were conducted on an NVIDIA A100 (40GB/80GB). Measurements were averaged over eight independent runs to reduce run-to-run variance.
- Transformers version: 4.57.6
- Task: LLaDA-MoE inference (max length 128)
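The benchmarking script is not included in this section. As a hedged sketch, assuming each reported figure is the mean ± sample standard deviation of per-run baseline/optimized ratios over paired runs, the summary statistic could be computed as follows. The helper name `ratio_stats` and the timings are hypothetical, not measured data.

```python
import statistics


# Hypothetical helper: summarizes paired runs as mean ± sample stdev of ratios.
def ratio_stats(baseline_ms, optimized_ms):
    """Return (mean, stdev) of per-run baseline/optimized time ratios.

    Assumes the lists are paired: baseline run i is compared with
    optimized run i.
    """
    if len(baseline_ms) != len(optimized_ms) or len(baseline_ms) < 2:
        raise ValueError("need at least two paired runs")
    ratios = [b / o for b, o in zip(baseline_ms, optimized_ms)]
    return statistics.mean(ratios), statistics.stdev(ratios)


# Hypothetical timings (ms) for eight runs, for illustration only.
baseline = [20.0, 20.2, 19.8, 20.1, 20.0, 19.9, 20.3, 20.0]
optimized = [10.0, 10.1, 10.0, 10.0, 10.1, 9.9, 10.2, 10.0]
mean, std = ratio_stats(baseline, optimized)
print(f"speedup: {mean:.4f} ± {std:.4f}")
```

On the GPU, per-run times would typically be collected with `torch.cuda.Event(enable_timing=True)` and `elapsed_time`, after a `torch.cuda.synchronize()`.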
By transitioning from fragmented memory access to the Sort-Compute-Scatter pipeline, we achieved:
- CUDA Speedup: mean execution-time ratio of $1.8883 \pm 0.0030$
- Memory Throughput: effective bandwidth utilization increased by a factor of $1.9251 \pm 0.0066$
- Hardware Efficiency: SM utilization improved by a factor of $1.2541 \pm 0.0009$ (batch size 1)
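The project's own kernels are not reproduced here. As a rough illustration of the general Sort-Compute-Scatter idea (a sketch under stated assumptions, not the project's code), assume a top-k router and experts given as a list of modules that map `[N, H]` to `[N, H]`. All (token, expert) pairs are sorted by expert id so each expert processes one contiguous slice, and the weighted results are scattered back with `index_add_`. The function names `sort_compute_scatter` and `naive_moe` are hypothetical; the latter is a "fragmented" mask-per-expert loop used as a reference.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the Sort-Compute-Scatter idea; not the project's code.


def naive_moe(x, topk_idx, topk_w, experts):
    """Reference 'fragmented' loop: mask-gather each expert's tokens."""
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        y = expert(x[token_ids]) * topk_w[token_ids, slot].unsqueeze(-1)
        out.index_add_(0, token_ids, y)
    return out


def sort_compute_scatter(x, topk_idx, topk_w, experts):
    """x: [T, H]; topk_idx, topk_w: [T, K]; experts map [N, H] -> [N, H]."""
    num_tokens, k = topk_idx.shape
    flat_expert = topk_idx.reshape(-1)  # [T*K], row-major: token 0 slots first
    flat_token = torch.arange(num_tokens, device=x.device).repeat_interleave(k)
    flat_w = topk_w.reshape(-1)

    # Sort: group all (token, expert) pairs by expert id.
    order = torch.argsort(flat_expert, stable=True)
    sorted_token, sorted_w = flat_token[order], flat_w[order]
    x_sorted = x[sorted_token]  # single gather, contiguous per expert

    # Compute: each expert processes one contiguous slice (empty slices skipped).
    counts = torch.bincount(flat_expert, minlength=len(experts)).tolist()
    chunks = torch.split(x_sorted, counts)
    y_sorted = torch.cat([experts[e](c) for e, c in enumerate(chunks) if c.shape[0] > 0])

    # Scatter: weight the outputs and add them back to their source tokens.
    out = torch.zeros_like(x)
    out.index_add_(0, sorted_token, y_sorted * sorted_w.unsqueeze(-1))
    return out


torch.manual_seed(0)
T, H, E, K = 16, 8, 4, 2
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(H, 2 * H), nn.SiLU(), nn.Linear(2 * H, H)) for _ in range(E)
)
x = torch.randn(T, H)
topk_w, topk_idx = torch.topk(torch.randn(T, E).softmax(dim=-1), K, dim=-1)
with torch.no_grad():
    ref = naive_moe(x, topk_idx, topk_w, experts)
    fast = sort_compute_scatter(x, topk_idx, topk_w, experts)
print(torch.allclose(ref, fast, atol=1e-5))
```

In this sketch the Python loop over experts remains, but each expert reads one contiguous slice instead of a masked gather, and all results are combined with a single `index_add_`. Whether the project goes further (e.g., grouped GEMMs) is not stated here.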
The optimization maintains mathematical equivalence with the native implementation, and task accuracy is unchanged:
- Benchmark: GSM8K test set (50 samples)
- Accuracy: 50% for both the baseline and the optimized version
- Verification: on this benchmark, Sort-Compute-Scatter does not degrade model accuracy
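The evaluation script is not part of this section. As a hypothetical sketch of exact-match scoring on GSM8K, assuming references follow GSM8K's `#### <number>` convention and taking the last number in the model output as the prediction (an assumed extraction rule; the project's may differ):

```python
import re

# Hypothetical scoring helpers; the project's actual evaluation may differ.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_reference(answer: str) -> str:
    """GSM8K reference solutions end with '#### <final answer>'."""
    return answer.split("####")[-1].strip().replace(",", "")


def extract_prediction(output: str):
    """Assumed rule: take the last number appearing in the model output."""
    numbers = _NUMBER.findall(output)
    return numbers[-1].replace(",", "") if numbers else None


def is_correct(output: str, answer: str) -> bool:
    pred = extract_prediction(output)
    if pred is None:
        return False
    try:
        # Compare numerically so "18.0" matches "18".
        return float(pred) == float(extract_reference(answer))
    except ValueError:
        return False


def accuracy(outputs, answers) -> float:
    return sum(is_correct(o, a) for o, a in zip(outputs, answers)) / len(answers)
```

With 50 samples, each question is worth 2 percentage points, so 50% accuracy corresponds to 25 correct answers for each version.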
Clone the repository:

```bash
git clone https://github.com/your-username/llada-moe-optimization.git
cd llada-moe-optimization
```
Install dependencies:

```bash
pip install torch transformers packaging
```
Note: This optimization is validated against `transformers==4.57.6`.
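Because the optimization is validated only against one `transformers` release, a guard like the following can warn when a different version is installed. This is a hedged sketch using the `packaging` dependency listed above; the helper `check_transformers_version` is hypothetical and not part of the project's API.

```python
import warnings
from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging.version import Version

TESTED_TRANSFORMERS = Version("4.57.6")


# Hypothetical helper; not part of the project's API.
def check_transformers_version(installed=None) -> bool:
    """Return True if the installed transformers matches the tested release.

    Warns instead of failing on a mismatch, since the patch may still work
    but has not been validated for that version.
    """
    if installed is None:
        try:
            installed = installed_version("transformers")
        except PackageNotFoundError:
            warnings.warn("transformers is not installed")
            return False
    if Version(installed) != TESTED_TRANSFORMERS:
        warnings.warn(
            f"transformers {installed} is untested; validated with {TESTED_TRANSFORMERS}"
        )
        return False
    return True
```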
```python
import torch
from transformers import AutoModelForCausalLM
from llada_moe_optimization import optimize_llada_moe

# Load your LLaDA-MoE model; trust_remote_code=True lets transformers
# load the model's custom modeling code from the Hub
model = AutoModelForCausalLM.from_pretrained(
    "inclusionAI/LLaDA-MoE-7B-A1B-Instruct",
    device_map="auto",
    trust_remote_code=True,
)

# Patch the model with optimized MoE blocks
optimized_model = optimize_llada_moe(model)

# For generation, use the code from section 2 of
# https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
```

No Speedup: transformers

This optimization is particularly effective for:
- Different GPUs: implemented in pure PyTorch with no custom kernels, so it is not tied to a specific GPU (benchmarks were run only on an A100)
- High Expert Count: Parallelizes the "Expert Loop" logic more effectively
- Triton Integration: (Coming Soon) Custom kernels for specialized MoE activation
This implementation relies on monkey-patching internals of `LLaDAMoESparseMoeBlock`, so library versions other than the tested ones may not be compatible:
| Component | Tested Version | Status |
|---|---|---|
| Transformers | 4.57.6 | Stable |
| PyTorch | 2.0+ | Stable |
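The exact mechanics of `optimize_llada_moe` are not shown in this section. One common way to monkey-patch a model, sketched here as an assumption, is to walk the module tree and swap every instance of a target class (matched by class name, e.g. `LLaDAMoESparseMoeBlock`) for a replacement built from it. The helper `replace_modules_by_class_name` and the toy classes are hypothetical.

```python
import torch.nn as nn


# Hypothetical helper illustrating module swapping; not the project's API.
def replace_modules_by_class_name(model: nn.Module, class_name: str, make_replacement) -> int:
    """Replace every submodule whose class is named `class_name`.

    `make_replacement(old_module)` returns the new module and may reuse the
    old module's parameters. Returns the number of modules replaced.
    """
    # Collect targets first so the module tree is not mutated mid-iteration.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name and type(module).__name__ == class_name
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_replacement(module))
    return len(targets)


# Toy stand-ins for an MoE block and its optimized replacement.
class ToyMoeBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class FastToyMoeBlock(nn.Module):
    def __init__(self, old: ToyMoeBlock):
        super().__init__()
        self.proj = old.proj  # reuse the original weights, no copy

    def forward(self, x):
        return self.proj(x)


class ToyLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = ToyMoeBlock(dim)


model = nn.ModuleDict({"layers": nn.ModuleList([ToyLayer(4) for _ in range(2)])})
old_proj = model["layers"][0].mlp.proj
n = replace_modules_by_class_name(model, "ToyMoeBlock", FastToyMoeBlock)
print(n)
```

Matching on the class name rather than importing the class avoids a hard dependency on the model's remote-code module, at the cost of silently matching nothing if the name changes in a future release.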
Contributions are welcome! If you have ideas for further optimizations (e.g., Triton kernels for the router), please open an issue or submit a PR.

Contact: alexeymanakonov@gmail.com (Gmail), @alexeo_0man (Telegram)

This project is licensed under the MIT License; see the LICENSE file for details.