MLX implementation of V-JEPA 2 (Video Joint-Embedding Predictive Architecture) for Apple Silicon.
This repository contains the MLX-optimized version of V-JEPA 2, designed to run efficiently on Apple Silicon (M1/M2/M3) using the MLX framework.
- MLX-optimized models: Native implementation using Apple's MLX framework for maximum performance on Apple Silicon
- Video understanding: Pretrained vision transformers for video representation learning
- SSv2 classifier training: Training pipeline for Something-Something-v2 action recognition
- Frozen encoder fine-tuning: Efficient classifier training with frozen pretrained encoders
- Notebook examples: Interactive Jupyter notebooks for experimentation
- Apple Silicon Mac (M1/M2/M3 or later)
- macOS 13.0 or later
- Python 3.11 or later
- MLX 0.20.0 or later
# Clone the repository
git clone https://github.com/facebookresearch/vjepa2-mlx.git
cd vjepa2-mlx
# Create a virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate
# Install the package
pip install -e .
# For development with additional tools
pip install -e ".[dev]"
# For notebook support
pip install -e ".[notebook]"pip install -r requirements.txt
pip install -e .

Download the MLX-converted pretrained V-JEPA 2 weights from the main repository:
# Create weights directory
mkdir -p weights
# Download ViT-Large MLX weights
# (You'll need to convert PyTorch weights using the provided conversion script)

Train a video action classifier on Something-Something-v2:
python train_ssv2_classifier.py \
--videos-dir /path/to/ssv2/videos \
--labels-dir /path/to/ssv2/labels \
--pretrained-weights weights/vitl_mlx.safetensors \
--output-dir output_ssv2_classifier \
--batch-size 4 \
--num-epochs 10 \
--use-wandb

Or use the config file:
python train_ssv2_classifier.py --config configs/train/ssv2_classifier_default.yaml

Explore the examples in Jupyter notebooks:
jupyter notebook notebooks/vjepa2_mlx_demo.ipynb

Test the code in an environment that matches GitHub Actions CI:
# Build the Docker image
./scripts/docker_test.sh build
# Run all tests
./scripts/docker_test.sh test
# Open interactive shell for debugging
./scripts/docker_test.sh shell

See DOCKER.md for detailed Docker testing documentation.
# Install test dependencies
pip install -r requirements-test.txt
# Run tests
pytest tests/

See tests/README.md for more testing documentation.
vjepa2-mlx/
├── src/
│ └── vjepa2_mlx/ # Main package
│ ├── models/ # Model implementations
│ │ ├── vision_transformer.py
│ │ ├── attentive_pooler.py
│ │ ├── predictor.py
│ │ └── ac_predictor.py
│ └── utils/ # Utility modules
│ ├── modules.py
│ ├── patch_embed.py
│ └── pos_embs.py
├── configs/ # Configuration files
│ └── train/ # Training configs
│ └── ssv2_classifier_default.yaml
├── notebooks/ # Jupyter notebooks
│ ├── vjepa2_mlx_demo.ipynb
│ └── ssv2_classifier_training_mlx.ipynb
├── tests/ # Unit tests
├── scripts/ # Utility scripts
├── train_ssv2_classifier.py # Main training script
├── requirements.txt # Package dependencies
├── requirements-dev.txt # Development dependencies
├── pyproject.toml # Project metadata
└── setup.py # Setup script
The core encoder is a Vision Transformer adapted for video input:
- ViT-Large/16: 24 layers, 1024 hidden dim, 16 attention heads
- 3D Patch Embedding: Tubelet size 2 for temporal dimension
- RoPE: Rotary position embeddings over the spatiotemporal token grid
- Frozen Encoder: Pretrained weights are frozen during classifier training (a minimal setup sketch follows below)
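The snippet below is a minimal sketch of this frozen-encoder setup in MLX. The class names and constructor arguments (VisionTransformer, AttentivePooler, num_classes) are assumptions inferred from the repository layout, not a confirmed public API; train_ssv2_classifier.py remains the supported entry point.

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Hypothetical imports: class names and defaults are inferred from the
# src/vjepa2_mlx/models/ layout and may not match the real API exactly.
from vjepa2_mlx.models.vision_transformer import VisionTransformer
from vjepa2_mlx.models.attentive_pooler import AttentivePooler

encoder = VisionTransformer()                         # ViT-L/16 video encoder (assumed defaults)
encoder.load_weights("weights/vitl_mlx.safetensors")  # MLX-converted pretrained weights
encoder.freeze()                                      # exclude encoder params from gradient updates

head = AttentivePooler(num_classes=174)               # SSv2 has 174 action classes

def loss_fn(head, clips, labels):
    features = encoder(clips)                         # frozen forward pass over video tokens
    logits = head(features)
    return nn.losses.cross_entropy(logits, labels, reduction="mean")

optimizer = optim.AdamW(learning_rate=1e-3)
step = nn.value_and_grad(head, loss_fn)               # gradients only for the trainable head

# One optimization step (clips/labels come from the SSv2 data loader):
#   loss, grads = step(head, clips, labels)
#   optimizer.update(head, grads)
#   mx.eval(head.parameters(), optimizer.state)
```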
A lightweight classifier head with attention pooling:
- Learns to attend to relevant spatiotemporal features
- Efficient training with frozen encoder
- Suitable for action recognition tasks (an illustrative pooling sketch follows below)
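Conceptually, the pooler lets a learnable query token cross-attend over the encoder's patch tokens and feeds the pooled representation to a linear classification head. The sketch below is an illustrative, self-contained MLX version of that idea (single query, SSv2's 174 classes), not the repository's exact attentive_pooler.py implementation:

```python
import mlx.core as mx
import mlx.nn as nn

class SimpleAttentivePooler(nn.Module):
    """A single learnable query cross-attends over frozen encoder tokens."""

    def __init__(self, dim: int = 1024, num_heads: int = 16, num_classes: int = 174):
        super().__init__()
        self.query = mx.zeros((1, 1, dim))              # learnable probe token
        self.attn = nn.MultiHeadAttention(dim, num_heads)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def __call__(self, tokens: mx.array) -> mx.array:
        # tokens: (batch, num_patches, dim) features from the frozen encoder
        q = mx.broadcast_to(self.query, (tokens.shape[0], 1, tokens.shape[-1]))
        pooled = self.attn(q, tokens, tokens)           # (batch, 1, dim)
        return self.head(self.norm(pooled[:, 0, :]))    # (batch, num_classes)

# Example with random features: 8 clips x 1568 tokens x 1024 dims -> (8, 174) logits
logits = SimpleAttentivePooler()(mx.random.normal((8, 1568, 1024)))
print(logits.shape)
```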
Edit configs/train/ssv2_classifier_default.yaml to customize:
- Dataset paths
- Model architecture (frames, resolution, etc.)
- Training hyperparameters (batch size, learning rate, epochs)
- Weights & Biases logging
- Output directory and checkpointing
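The default config can also be loaded and adjusted programmatically before launching a run. Below is a small sketch using PyYAML; the key names are illustrative guesses, so check the default YAML for the actual schema:

```python
import yaml

# Read the default training config shipped with the repository
with open("configs/train/ssv2_classifier_default.yaml") as f:
    cfg = yaml.safe_load(f)

# Override a few fields (key names are hypothetical; match them to the real file)
cfg["batch_size"] = 2
cfg["num_epochs"] = 20
cfg["output_dir"] = "output_ssv2_low_mem"

# Write a custom config and pass it to the training script via --config
with open("configs/train/ssv2_classifier_custom.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```

The custom file can then be passed to train_ssv2_classifier.py with --config, as shown above.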
To see all available command line options:

python train_ssv2_classifier.py --help

Key arguments:
- --videos-dir: Path to video files
- --labels-dir: Path to label JSON files
- --pretrained-weights: Path to pretrained encoder weights
- --batch-size: Training batch size (default: 4)
- --num-epochs: Number of training epochs (default: 10)
- --learning-rate: Learning rate (default: 1e-3)
- --use-wandb: Enable Weights & Biases logging
- --output-dir: Output directory for checkpoints
- --resume-from: Resume from checkpoint
For larger effective batch sizes on limited memory:
python train_ssv2_classifier.py \
--batch-size 2 \
--gradient-accumulation-steps 4  # Effective batch size = 8

MLX provides significant performance benefits on Apple Silicon:
- Memory Efficiency: Unified memory architecture reduces overhead
- Native Optimization: Hardware-accelerated operations
- Energy Efficiency: Lower power consumption compared to external GPUs
Typical training speeds on M2 Max (64GB):
- ViT-Large/16: ~2-3 samples/sec with batch size 4
- Memory usage: ~8-12GB during training
To convert pretrained PyTorch weights to MLX format:
from vjepa2_mlx.convert_weights import convert_state_dict
import mlx.core as mx
import torch
# Load PyTorch weights
torch_weights = torch.load('vjepa2_vitl.pth')
# Convert to MLX
mlx_weights = convert_state_dict(torch_weights)
# Save
mx.save_safetensors('vitl_mlx.safetensors', mlx_weights)

If you use this code, please cite the original V-JEPA 2 paper:
@article{vjepa2,
title={V-JEPA 2: Video Joint-Embedding Predictive Architecture},
author={Facebook Research},
journal={arXiv preprint},
year={2024}
}

This project is licensed under the MIT License - see the LICENSE file for details.
- Original V-JEPA 2 implementation by Facebook Research
- MLX framework by Apple
- Something-Something-v2 dataset by TwentyBN
Contributions are welcome! Please feel free to submit a Pull Request.
For issues and questions:
- Open an issue on GitHub
- Check the main V-JEPA 2 repository for general questions
- Refer to MLX documentation for framework-specific questions