Skip to content

Conversation

@gaarutyunov
Copy link
Owner

This commit adds a comprehensive Rust implementation of the V-JEPA 2 model
using mlx-rs (Rust bindings for Apple's MLX framework).

Rust Implementation (vjepa2-rs/)

Completed Components:

  • ✅ Core transformer modules (src/modules.rs):

    • MLP with GELU/SiLU activation
    • SwiGLU FFN (Swish-Gated Linear Unit)
    • Standard multi-head Attention
    • RoPEAttention (Rotary Position Embeddings for 3D video inputs)
    • Transformer Block with residual connections and drop path
    • rotate_queries_or_keys function for RoPE
  • ✅ Patch embedding layers (src/patch_embed.rs):

    • PatchEmbed: 2D Conv-based patch embedding for images
    • PatchEmbed3D: 3D Conv-based patch embedding for videos
  • ✅ Positional embeddings (src/pos_embs.rs):

    • 1D and 2D sinusoidal positional embeddings
  • ✅ Vision Transformer structure (src/vision_transformer.rs):

    • Basic VisionTransformer with patch embedding and blocks
  • ✅ Error handling (src/error.rs):

    • Custom error types with thiserror
  • ✅ Documentation:

    • Comprehensive README.md with usage instructions
    • Inline documentation for all modules
    • Design decisions documented (e.g., RoPE bug replication)

Key Features:

  • Faithful port of Python MLX implementation
  • Replicates PyTorch RoPE behavior for pretrained weight compatibility
  • LayerNorm eps=1e-6 to match PyTorch exactly
  • 3D position separation for video inputs (depth, height, width)
  • Type-safe error handling

Apple Silicon Requirement:

⚠️ The Rust port requires macOS with Apple Silicon (M1/M2/M3) due to
MLX's dependency on Metal and Accelerate frameworks. It will NOT compile
on Linux or Intel-based systems.

Testing Infrastructure

New Python Component Tests:

  • Added comprehensive component output tests (tests/test_component_outputs.py)
  • Tests validate individual components:
    • RoPE rotation function
    • MLP (GELU vs SiLU)
    • SwiGLU FFN
    • Standard Attention
    • RoPE Attention
    • Transformer Block
    • Patch embedding layers
    • Positional embeddings
  • Parametric tests for various configurations
  • Determinism and shape validation tests

CI/CD Updates:

  • Added component output tests to GitHub Actions workflow
  • Added conditional Rust build job for macOS (Apple Silicon)
  • Rust job only runs on workflow_dispatch or with [test-rust] in commit message
  • Continues on error for Rust tests (implementation ongoing)

Documentation Updates

Main README.md:

  • Added Rust Port section with status and requirements
  • Updated repository structure to include vjepa2-rs/
  • Documented Apple Silicon requirement
  • Added "Why Rust?" section explaining benefits

Rust README (vjepa2-rs/README.md):

  • Detailed component status and TODO list
  • Implementation guide with examples
  • Key design decisions documented
  • Testing strategy outlined
  • Python comparison approach described

Project Structure:

vjepa2-rs/
├── src/
│   ├── lib.rs                    # Main library entry point
│   ├── error.rs                  # Error types
│   ├── modules.rs                # Core transformer modules (808 lines)
│   ├── patch_embed.rs            # Patch embedding layers
│   ├── pos_embs.rs               # Positional embeddings
│   └── vision_transformer.rs     # VisionTransformer model
├── tests/
│   └── python_comparison.rs      # (TODO) Python-Rust comparison tests
├── Cargo.toml                    # Rust package manifest
├── .gitignore                    # Rust-specific gitignore
└── README.md                     # Comprehensive documentation

Next Steps (TODO):

  • Implement predictor models (VisionTransformerPredictor, VisionTransformerPredictorAC)
  • Implement AttentivePooler and AttentiveClassifier
  • Add weight loading from Python SafeTensors checkpoints
  • Implement Python-Rust comparison tests (requires macOS)
  • Add integration tests for full model forward pass
  • Numerical accuracy validation (MAE, cosine similarity)

Note: The Rust implementation serves as a foundation and reference.
Full testing and validation requires macOS with Apple Silicon to compile
and run the mlx-rs-based code.

This commit adds a comprehensive Rust implementation of the V-JEPA 2 model
using mlx-rs (Rust bindings for Apple's MLX framework).

## Rust Implementation (vjepa2-rs/)

### Completed Components:
- ✅ Core transformer modules (src/modules.rs):
  - MLP with GELU/SiLU activation
  - SwiGLU FFN (Swish-Gated Linear Unit)
  - Standard multi-head Attention
  - RoPEAttention (Rotary Position Embeddings for 3D video inputs)
  - Transformer Block with residual connections and drop path
  - rotate_queries_or_keys function for RoPE

- ✅ Patch embedding layers (src/patch_embed.rs):
  - PatchEmbed: 2D Conv-based patch embedding for images
  - PatchEmbed3D: 3D Conv-based patch embedding for videos

- ✅ Positional embeddings (src/pos_embs.rs):
  - 1D and 2D sinusoidal positional embeddings

- ✅ Vision Transformer structure (src/vision_transformer.rs):
  - Basic VisionTransformer with patch embedding and blocks

- ✅ Error handling (src/error.rs):
  - Custom error types with thiserror

- ✅ Documentation:
  - Comprehensive README.md with usage instructions
  - Inline documentation for all modules
  - Design decisions documented (e.g., RoPE bug replication)

### Key Features:
- Faithful port of Python MLX implementation
- Replicates PyTorch RoPE behavior for pretrained weight compatibility
- LayerNorm eps=1e-6 to match PyTorch exactly
- 3D position separation for video inputs (depth, height, width)
- Type-safe error handling

### Apple Silicon Requirement:
⚠️ The Rust port requires macOS with Apple Silicon (M1/M2/M3) due to
MLX's dependency on Metal and Accelerate frameworks. It will NOT compile
on Linux or Intel-based systems.

## Testing Infrastructure

### New Python Component Tests:
- Added comprehensive component output tests (tests/test_component_outputs.py)
- Tests validate individual components:
  - RoPE rotation function
  - MLP (GELU vs SiLU)
  - SwiGLU FFN
  - Standard Attention
  - RoPE Attention
  - Transformer Block
  - Patch embedding layers
  - Positional embeddings
- Parametric tests for various configurations
- Determinism and shape validation tests

### CI/CD Updates:
- Added component output tests to GitHub Actions workflow
- Added conditional Rust build job for macOS (Apple Silicon)
- Rust job only runs on workflow_dispatch or with [test-rust] in commit message
- Continues on error for Rust tests (implementation ongoing)

## Documentation Updates

### Main README.md:
- Added Rust Port section with status and requirements
- Updated repository structure to include vjepa2-rs/
- Documented Apple Silicon requirement
- Added "Why Rust?" section explaining benefits

### Rust README (vjepa2-rs/README.md):
- Detailed component status and TODO list
- Implementation guide with examples
- Key design decisions documented
- Testing strategy outlined
- Python comparison approach described

## Project Structure:
```
vjepa2-rs/
├── src/
│   ├── lib.rs                    # Main library entry point
│   ├── error.rs                  # Error types
│   ├── modules.rs                # Core transformer modules (808 lines)
│   ├── patch_embed.rs            # Patch embedding layers
│   ├── pos_embs.rs               # Positional embeddings
│   └── vision_transformer.rs     # VisionTransformer model
├── tests/
│   └── python_comparison.rs      # (TODO) Python-Rust comparison tests
├── Cargo.toml                    # Rust package manifest
├── .gitignore                    # Rust-specific gitignore
└── README.md                     # Comprehensive documentation
```

## Next Steps (TODO):
- [ ] Implement predictor models (VisionTransformerPredictor, VisionTransformerPredictorAC)
- [ ] Implement AttentivePooler and AttentiveClassifier
- [ ] Add weight loading from Python SafeTensors checkpoints
- [ ] Implement Python-Rust comparison tests (requires macOS)
- [ ] Add integration tests for full model forward pass
- [ ] Numerical accuracy validation (MAE, cosine similarity)

Note: The Rust implementation serves as a foundation and reference.
Full testing and validation requires macOS with Apple Silicon to compile
and run the mlx-rs-based code.
- Fix PatchEmbed tests to use correct constructor arguments:
  - PatchEmbed(patch_size, in_chans, embed_dim) instead of (img_size, ...)
  - PatchEmbed3D expects (B, T, H, W, C) format input
- Fix get_2d_sincos_pos_embed test: function takes (embed_dim, grid_size)
- Fix allclose test: pos_embs returns numpy arrays, use np.allclose
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants