V-JEPA 2 MLX

MLX implementation of V-JEPA 2 (Video Joint-Embedding Predictive Architecture) for Apple Silicon.

This repository contains the MLX-optimized version of V-JEPA 2, designed to run efficiently on Apple Silicon (M1/M2/M3) using the MLX framework.

Features

  • MLX-optimized models: Native implementation using Apple's MLX framework for maximum performance on Apple Silicon
  • Video understanding: Pretrained vision transformers for video representation learning
  • SSv2 classifier training: Training pipeline for Something-Something-v2 action recognition
  • Frozen encoder fine-tuning: Efficient classifier training with frozen pretrained encoders
  • Notebook examples: Interactive Jupyter notebooks for experimentation

Requirements

  • Apple Silicon Mac (M1/M2/M3 or later)
  • macOS 13.0 or later
  • Python 3.11 or later
  • MLX 0.20.0 or later

Installation

From source

# Clone the repository
git clone https://github.com/gaarutyunov/vjepa2-mlx.git
cd vjepa2-mlx

# Create a virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate

# Install the package
pip install -e .

# For development with additional tools
pip install -e ".[dev]"

# For notebook support
pip install -e ".[notebook]"

Quick install

pip install -r requirements.txt
pip install -e .
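
A quick sanity check after installing (this assumes the src/ layout installs the package under the top-level name vjepa2_mlx):

# Smoke test: MLX and the package should import cleanly on Apple Silicon
import mlx.core as mx
import vjepa2_mlx

print(mx.default_device())  # expected: Device(gpu, 0) on an M-series Mac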

Quick Start

1. Download Pretrained Weights

Download the original PyTorch V-JEPA 2 weights from the main repository and convert them to MLX format (see "Converting PyTorch Weights" below):

# Create weights directory
mkdir -p weights

# Download ViT-Large MLX weights
# (You'll need to convert PyTorch weights using the provided conversion script)

2. Training SSv2 Classifier

Train a video action classifier on Something-Something-v2:

python train_ssv2_classifier.py \
    --videos-dir /path/to/ssv2/videos \
    --labels-dir /path/to/ssv2/labels \
    --pretrained-weights weights/vitl_mlx.safetensors \
    --output-dir output_ssv2_classifier \
    --batch-size 4 \
    --num-epochs 10 \
    --use-wandb

Or use the config file:

python train_ssv2_classifier.py --config configs/train/ssv2_classifier_default.yaml

3. Interactive Notebooks

Explore the examples in Jupyter notebooks:

jupyter notebook notebooks/vjepa2_mlx_demo.ipynb

Testing

Local Testing with Docker

Test the code in an environment that matches GitHub Actions CI:

# Build the Docker image
./scripts/docker_test.sh build

# Run all tests
./scripts/docker_test.sh test

# Open interactive shell for debugging
./scripts/docker_test.sh shell

See DOCKER.md for detailed Docker testing documentation.

Native Testing

# Install test dependencies
pip install -r requirements-test.txt

# Run tests
pytest tests/

See tests/README.md for more testing documentation.

Repository Structure

vjepa2-mlx/
├── src/
│   └── vjepa2_mlx/           # Main package
│       ├── models/           # Model implementations
│       │   ├── vision_transformer.py
│       │   ├── attentive_pooler.py
│       │   ├── predictor.py
│       │   └── ac_predictor.py
│       └── utils/            # Utility modules
│           ├── modules.py
│           ├── patch_embed.py
│           └── pos_embs.py
├── configs/                  # Configuration files
│   └── train/               # Training configs
│       └── ssv2_classifier_default.yaml
├── notebooks/               # Jupyter notebooks
│   ├── vjepa2_mlx_demo.ipynb
│   └── ssv2_classifier_training_mlx.ipynb
├── tests/                   # Unit tests
├── scripts/                 # Utility scripts
├── train_ssv2_classifier.py # Main training script
├── requirements.txt         # Package dependencies
├── requirements-dev.txt     # Development dependencies
├── pyproject.toml          # Project metadata
└── setup.py                # Setup script

Models

Vision Transformer (ViT)

The core encoder is a Vision Transformer adapted for video input:

  • ViT-Large/16: 24 layers, 1024 hidden dim, 16 attention heads
  • 3D Patch Embedding: Tubelet size 2 along the temporal dimension (token count worked out below)
  • RoPE: Rotary position embeddings for better spatial understanding
  • Frozen Encoder: Pretrained weights are frozen during classifier training
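
To make the 3D patch embedding concrete, here is the token count for one clip under tubelet size 2 and 16x16 spatial patches (the 16 frames at 224x224 are illustrative numbers, not necessarily the repo's defaults):

# Token count under a tubelet-2, 16x16-patch 3D embedding (illustrative shapes)
frames, height, width = 16, 224, 224          # assumed clip shape
tubelet, patch = 2, 16
num_tokens = (frames // tubelet) * (height // patch) * (width // patch)
print(num_tokens)                             # 8 * 14 * 14 = 1568 tokens per clip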

Attentive Classifier

A lightweight classifier head with attention pooling:

  • Learns to attend to relevant spatiotemporal features
  • Trains efficiently on top of the frozen encoder
  • Suitable for action recognition tasks
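
A minimal sketch of the idea in MLX: one learnable query cross-attends over the frozen encoder's output tokens, and a linear layer maps the pooled vector to class logits. Names and arguments here are illustrative, not the package's actual AttentivePooler API:

import mlx.core as mx
import mlx.nn as nn

class AttentivePoolingHead(nn.Module):
    """Illustrative attentive classifier head (not the package's exact module)."""

    def __init__(self, dim: int, num_heads: int, num_classes: int):
        super().__init__()
        self.query = mx.random.normal((1, 1, dim)) * 0.02  # learnable query token
        self.attn = nn.MultiHeadAttention(dim, num_heads)
        self.head = nn.Linear(dim, num_classes)

    def __call__(self, tokens: mx.array) -> mx.array:
        # tokens: (batch, num_tokens, dim) from the frozen encoder
        batch = tokens.shape[0]
        queries = mx.broadcast_to(self.query, (batch, 1, tokens.shape[-1]))
        pooled = self.attn(queries, tokens, tokens)  # cross-attention pooling
        return self.head(pooled[:, 0])               # (batch, num_classes) logits

Because the encoder is frozen (MLX modules expose freeze() for this), only the query, attention, and linear-head parameters receive gradients, which is what keeps classifier training cheap.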

Training

Configuration

Edit configs/train/ssv2_classifier_default.yaml to customize:

  • Dataset paths
  • Model architecture (frames, resolution, etc.)
  • Training hyperparameters (batch size, learning rate, epochs)
  • Weights & Biases logging
  • Output directory and checkpointing
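
The config is plain YAML, so it can also be inspected or overridden from Python (requires PyYAML; check the file itself for the actual schema):

import yaml  # PyYAML

with open('configs/train/ssv2_classifier_default.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg)  # dict of dataset paths, model settings, and hyperparameters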

Command Line Arguments

python train_ssv2_classifier.py --help

Key arguments:

  • --videos-dir: Path to video files
  • --labels-dir: Path to label JSON files
  • --pretrained-weights: Path to pretrained encoder weights
  • --batch-size: Training batch size (default: 4)
  • --num-epochs: Number of training epochs (default: 10)
  • --learning-rate: Learning rate (default: 1e-3)
  • --use-wandb: Enable Weights & Biases logging
  • --output-dir: Output directory for checkpoints
  • --resume-from: Resume from checkpoint

Gradient Accumulation

For larger effective batch sizes on limited memory:

python train_ssv2_classifier.py \
    --batch-size 2 \
    --gradient-accumulation-steps 4  # effective batch size = 2 x 4 = 8
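
Under the hood this amounts to summing gradients over several micro-batches before taking a single optimizer step. A sketch of the pattern in MLX (the script's actual loop may differ; model, loss_fn, and the batch iterator are stand-ins):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

def train_with_accumulation(model, optimizer, batches, loss_fn, accum_steps=4):
    # Differentiates loss_fn with respect to the model's trainable parameters
    loss_and_grad = nn.value_and_grad(model, loss_fn)
    accum = None
    for step, (x, y) in enumerate(batches, start=1):
        loss, grads = loss_and_grad(model, x, y)  # loss could be logged here
        # Sum gradients across micro-batches instead of updating immediately
        accum = grads if accum is None else tree_map(lambda a, g: a + g, accum, grads)
        if step % accum_steps == 0:
            # Average and apply one optimizer step per effective batch
            optimizer.update(model, tree_map(lambda g: g / accum_steps, accum))
            mx.eval(model.parameters(), optimizer.state)
            accum = None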

Performance

MLX provides significant performance benefits on Apple Silicon:

  • Memory Efficiency: Unified memory architecture reduces overhead
  • Native Optimization: Hardware-accelerated operations
  • Energy Efficiency: Lower power consumption compared to discrete GPUs

Typical training speeds on M2 Max (64GB):

  • ViT-Large/16: ~2-3 samples/sec with batch size 4
  • Memory usage: ~8-12GB during training

Converting PyTorch Weights

To convert pretrained PyTorch weights to MLX format:

import mlx.core as mx
import torch

from vjepa2_mlx.convert_weights import convert_state_dict

# Load PyTorch weights on CPU (no CUDA needed for conversion)
torch_weights = torch.load('vjepa2_vitl.pth', map_location='cpu')

# Convert to MLX
mlx_weights = convert_state_dict(torch_weights)

# Save
mx.save_safetensors('vitl_mlx.safetensors', mlx_weights)
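
The converted file can be loaded back on the MLX side with mx.load, which returns a flat dict of parameter names to arrays. The model object below is a placeholder for whichever encoder you instantiate from vjepa2_mlx.models:

import mlx.core as mx

# .safetensors files load as a flat dict of name -> mx.array
weights = mx.load('vitl_mlx.safetensors')

# Hypothetical: push the weights into an instantiated MLX encoder
# model.load_weights(list(weights.items()))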

Citation

If you use this code, please cite the original V-JEPA 2 paper:

@article{assran2025vjepa2,
  title={V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
  author={Assran, Mido and others},
  journal={arXiv preprint arXiv:2506.09985},
  year={2025}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • Original V-JEPA 2 implementation by Facebook Research
  • MLX framework by Apple
  • Something-Something-v2 dataset by TwentyBN

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Support

For issues and questions:

  • Open an issue on GitHub
  • Check the main V-JEPA 2 repository for general questions
  • Refer to MLX documentation for framework-specific questions
