AbdelStark/jepa-rs

jepa-rs

Joint Embedding Predictive Architecture in Rust



Alpha Rust implementation of JEPA (Joint Embedding Predictive Architecture) — the self-supervised learning framework from Yann LeCun and Meta AI for learning world models that predict in representation space rather than pixel space.

jepa-rs provides modular, backend-agnostic building blocks for I-JEPA (images), V-JEPA (video), C-JEPA (causal object-centric), and hierarchical world models, built on top of the burn deep learning framework. It includes a CLI, an interactive TUI dashboard, a browser demo crate for local experimentation, safetensors checkpoint loading, ONNX metadata inspection, and a pretrained model registry for Facebook Research models.

                    ┌──────────────┐
                    │   Context    │──── Encoder ────┐
                    │   (visible)  │                 │
   Image/Video ─────┤              │         ┌───────▼───────┐
                    │   Target     │         │   Predictor   │──── predicted repr
                    │   (masked)   │──┐      └───────────────┘          │
                    └──────────────┘  │                                 │
                                      │      ┌───────────────┐          │
                                      └──────│ Target Encoder│── target repr
                                        EMA  │   (frozen)    │          │
                                             └───────────────┘          │
                                                                        │
                                             ┌───────────────┐          │
                                             │  Energy Loss  │◄─────────┘
                                             └───────────────┘
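The data flow above can be sketched in a few lines of plain Rust. Everything here is a toy stand-in (scalar "encoders" on f32 vectors, a fixed momentum value); the real pipeline operates on burn tensors and trained modules:

```rust
// Schematic JEPA step with toy linear "encoders" on f32 vectors.
// All names here are illustrative; jepa-rs uses burn tensors throughout.

fn encode(w: f32, x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| w * v).collect()
}

fn l2_energy(pred: &[f32], target: &[f32]) -> f32 {
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum()
}

fn main() {
    let image = vec![1.0_f32, 2.0, 3.0, 4.0];
    let (context, target_patch) = image.split_at(2); // visible vs masked

    let w_ctx = 0.5;     // context encoder weight (trained by the optimizer)
    let mut w_tgt = 0.5; // target encoder weight (EMA copy, frozen per step)

    let ctx_repr = encode(w_ctx, context);
    // "Predictor": maps context representation to a predicted target representation.
    let predicted: Vec<f32> = ctx_repr.iter().map(|v| v + 0.1).collect();
    let target_repr = encode(w_tgt, target_patch);

    // Energy is computed in representation space, never in pixel space.
    let energy = l2_energy(&predicted, &target_repr);
    println!("energy = {energy:.4}");

    // After the gradient step, the target encoder tracks the context encoder via EMA.
    let momentum = 0.996;
    w_tgt = momentum * w_tgt + (1.0 - momentum) * w_ctx;
    assert!(w_tgt > 0.0);
}
```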

Why jepa-rs?

| | jepa-rs | Python (PyTorch) |
|---|---|---|
| Runtime | Native binary, no Python/CUDA dependency | Requires Python + PyTorch + CUDA |
| Inference | Safetensors checkpoint loading, ONNX metadata | PyTorch runtime |
| Memory | Rust ownership, no GC pauses | Python GC + PyTorch allocator |
| Backend | Any burn backend (CPU, GPU, WebGPU, WASM) | CUDA-centric |
| Type safety | Compile-time tensor shape checks | Runtime shape errors |
| Deployment | Single static binary | Docker + Python environment |
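The compile-time shape-check claim rests on Rust const generics: burn tensors carry their rank in the type, so a rank mismatch is a build error rather than a runtime exception. A minimal standalone sketch of the pattern (the `Tensor` type below is a toy stand-in, not burn's actual type):

```rust
// Minimal sketch of rank-as-const-generic, the pattern behind burn's Tensor<B, D>.
// `Tensor` here is a toy stand-in, not burn's actual type.

#[derive(Debug, PartialEq)]
struct Tensor<const D: usize> {
    shape: [usize; D],
}

// Matmul is only defined for rank-2 tensors; calling it with a rank-3
// value simply does not compile, instead of raising a runtime shape error.
fn matmul(a: &Tensor<2>, b: &Tensor<2>) -> Tensor<2> {
    assert_eq!(a.shape[1], b.shape[0], "inner dimensions must agree");
    Tensor { shape: [a.shape[0], b.shape[1]] }
}

fn main() {
    let a = Tensor::<2> { shape: [4, 8] };
    let b = Tensor::<2> { shape: [8, 16] };
    let c = matmul(&a, &b);
    println!("{:?}", c.shape); // [4, 16]
}
```

Note that the inner-dimension check still happens at runtime; only the rank is enforced by the type system.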

Pretrained Models

jepa-rs supports loading official Facebook Research pretrained JEPA models:

| Model | Architecture | Params | Resolution | Dataset | Weights |
|---|---|---|---|---|---|
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-1K | Download \| HuggingFace |
| I-JEPA ViT-H/16-448 | ViT-Huge, patch 16 | 632M | 448x448 | ImageNet-1K | Download \| HuggingFace |
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-22K | Download |
| I-JEPA ViT-G/16 | ViT-Giant, patch 16 | 1.0B | 224x224 | ImageNet-22K | Download |
| V-JEPA ViT-L/16 | ViT-Large, patch 16 | 304M | 224x224 | VideoMix2M | Download |
| V-JEPA ViT-H/16 | ViT-Huge, patch 16 | 632M | 224x224 | VideoMix2M | Download |

Quick Start

Installation

# Cargo.toml
[dependencies]
jepa-core   = "0.1.0"
jepa-vision = "0.1.0"
jepa-compat = "0.1.0"  # For ONNX + checkpoint loading

CLI

The jepa binary provides a unified CLI for the workspace:

# Install the CLI from crates.io
cargo install jepa

# Or install from the local workspace checkout
cargo install --path crates/jepa

# Launch the interactive TUI dashboard
jepa

# List pretrained models in the registry
jepa models

# Inspect a safetensors checkpoint
jepa inspect model.safetensors

# Analyze checkpoint with key remapping
jepa checkpoint model.safetensors --keymap ijepa --verbose

# Launch a training run
jepa train --preset vit-base-16 --steps 10 --batch-size 1 --lr 1e-3

# Train from a normal image directory tree with deterministic resize/crop/normalize
jepa train --preset vit-base-16 --steps 100 --batch-size 4 \
  --dataset-dir ./images/train --resize 256 --crop-size 224 --shuffle

# Train from a safetensors image tensor dataset [N, C, H, W]
jepa train --preset vit-base-16 --steps 100 --batch-size 1 \
  --dataset train.safetensors --dataset-key images

# Encode inputs through a safetensors checkpoint
jepa encode --model model.safetensors --preset vit-base-16

# Or through an ONNX model
jepa encode --model model.onnx --height 224 --width 224

The CLI train command now runs real strict masked-image optimization with AdamW and EMA. It chooses one input source per run:

  • --dataset-dir <PATH> for a recursive image-folder dataset (jpg, jpeg, png, webp) with decode, RGB conversion, shorter-side resize, center crop, CHW tensor conversion, and normalization
  • --dataset <FILE> --dataset-key <KEY> for a safetensors image tensor shaped [N, C, H, W]
  • no dataset flags for the synthetic random-tensor fallback

Image-folder preprocessing defaults to the preset image size for --crop-size and to the ImageNet RGB normalization statistics when --mean and --std are omitted. Dataset loading is currently single-threaded. jepa encode runs inference with real encoder weights for .safetensors and .onnx inputs; other extensions still fall back to the preset demo path.
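For reference, the normalization step applied after resize, crop, and CHW conversion is the standard per-channel affine transform; a standalone sketch (the constants are the well-known ImageNet RGB statistics the CLI defaults to):

```rust
// Per-channel normalization as applied after resize/crop/CHW conversion.
// Mean/std are the standard ImageNet RGB statistics.
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

fn normalize_pixel(rgb: [f32; 3]) -> [f32; 3] {
    let mut out = [0.0; 3];
    for c in 0..3 {
        // Pixel values are assumed already scaled to [0, 1].
        out[c] = (rgb[c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
    out
}

fn main() {
    // A mid-gray pixel after scaling to [0, 1].
    let z = normalize_pixel([0.5, 0.5, 0.5]);
    println!("{z:?}");
}
```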

Runnable Examples

The jepa crate now ships runnable examples under crates/jepa/examples/ that exercise the real training command instead of mocking the CLI path:

# Create a tiny recursive image-folder dataset under target/example-data/jepa/
cargo run -p jepa --example prepare_demo_image_folder

# Train for 2 steps on that generated image-folder dataset
cargo run -p jepa --example train_image_folder_demo

# Train for 2 steps with the synthetic fallback path
cargo run -p jepa --example train_synthetic_demo

The image-folder example deliberately uses a very small generated dataset (6 PNG files across nested subdirectories). That is enough for a meaningful smoke demo of recursive dataset discovery, decode, resize, crop, normalize, batching, masking, optimizer updates, and EMA without checking a large image corpus into git. It is not large enough to demonstrate real representation learning quality; it is an execution demo, not a benchmark dataset.

The TUI now incorporates these demos in the Training tab as a guided demo runner. Launch jepa, switch to tab 3, choose a demo with j/k, and press Enter to run it. The panel streams real run logs, step metrics, loss/energy charts, and a short interpretation of what happened.

The TUI Inference tab (tab 4) adds a separate guided walkthrough for encoder inference. It runs deterministic demo image patterns through a preset ViT, streams phase changes, per-sample latency, and embedding statistics, and explains what the representation telemetry means. The walkthrough is intentionally a pipeline demo rather than a pretrained semantic benchmark.

If you want to run the CLI directly after generating the demo dataset:

cargo run -p jepa -- train --preset vit-small-16 --steps 2 --batch-size 2 \
  --dataset-dir target/example-data/jepa/demo-image-folder \
  --resize 256 --crop-size 224 --shuffle --dataset-limit 6

Loading SafeTensors Checkpoints

use jepa_compat::safetensors::load_checkpoint;
use jepa_compat::keymap::ijepa_vit_keymap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mappings = ijepa_vit_keymap();
    let checkpoint = load_checkpoint("model.safetensors", &mappings)?;

    println!("Loaded {} tensors", checkpoint.len());
    for key in checkpoint.keys() {
        println!("  {}: {:?}", key, checkpoint.get(key).unwrap().shape);
    }
    Ok(())
}

Building JEPA Models from Scratch

use burn::prelude::*;
use burn_ndarray::NdArray;
use jepa_core::masking::{BlockMasking, MaskingStrategy};
use jepa_core::types::InputShape;
use jepa_vision::image::IJepaConfig;
use jepa_vision::vit::VitConfig;

type B = NdArray<f32>;

fn main() {
    let device = burn_ndarray::NdArrayDevice::Cpu;

    // Configure I-JEPA with ViT-Huge/14 (matches Facebook pretrained)
    let config = IJepaConfig {
        encoder: VitConfig::vit_huge_patch14(),
        predictor: jepa_vision::image::TransformerPredictorConfig {
            encoder_embed_dim: 1280,
            predictor_embed_dim: 384,
            num_layers: 12,
            num_heads: 12,
            max_target_len: 256,
        },
    };
    let model = config.init::<B>(&device);

    // Generate masks (I-JEPA block masking)
    let shape = InputShape::Image { height: 16, width: 16 }; // 224/14 = 16
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
    let masking = BlockMasking {
        num_targets: 4,
        target_scale: (0.15, 0.2),
        target_aspect_ratio: (0.75, 1.5),
    };
    let mask = masking.generate_mask(&shape, &mut rng);

    println!("Context tokens: {}, Target tokens: {}",
             mask.context_indices.len(), mask.target_indices.len());
}

Browse Available Models

use jepa_compat::registry::{list_models, find_model};

fn main() {
    for model in list_models() {
        println!("{}: {} ({}, {})",
            model.name,
            model.param_count_human(),
            model.architecture,
            model.pretrained_on);
    }

    // Search for a specific model
    if let Some(m) = find_model("vit-h/14") {
        println!("\nFound: {} with {} patches",
            m.name, m.num_patches());
    }
}

Architecture

jepa-rs/
├── jepa-core        Core traits, tensor wrappers, masking, energy, EMA
│   ├── Encoder          Trait for context/target encoders
│   ├── Predictor        Trait for latent predictors
│   ├── EnergyFn         L2, Cosine, SmoothL1 energy functions
│   ├── MaskingStrategy  Block, MultiBlock, Spatiotemporal, Object masking
│   ├── CollapseReg      VICReg, BarlowTwins collapse prevention
│   └── EMA              Exponential moving average with cosine schedule
│
├── jepa-vision      Vision transformers and JEPA models
│   ├── VitEncoder       ViT-S/B/L/H/G with 2D RoPE
│   ├── IJepa            I-JEPA pipeline (image)
│   ├── VJepa            V-JEPA pipeline (video, 3D tubelets)
│   ├── SlotAttention    Slot attention encoder for object-centric representations
│   └── Predictor        Transformer-based cross-attention predictor
│
├── jepa-world       World models and planning
│   ├── ActionPredictor  Action-conditioned latent prediction
│   ├── ObjectDynamics   Transformer dynamics predictor for object slots
│   ├── Planner          Random shooting planner with cost functions
│   ├── HierarchicalJepa Multi-level H-JEPA
│   └── ShortTermMemory  Sliding-window memory for temporal context
│
├── jepa-train       Training orchestration
│   ├── TrainConfig      Learning rate schedules, EMA config
│   ├── JepaComponents   Generic forward step orchestration
│   ├── CausalJepa       C-JEPA training loop (frozen encoder, object masking)
│   └── CheckpointMeta   Save/resume metadata
│
├── jepa-compat      Model compatibility and interop
│   ├── ModelRegistry     Pretrained model catalog (Facebook Research)
│   ├── SafeTensors       Load .safetensors checkpoints
│   ├── KeyMap            PyTorch → burn key remapping
│   └── OnnxModelInfo     ONNX metadata inspection and initializer loading
│
├── jepa             CLI and interactive TUI dashboard
│   ├── CLI               models, inspect, checkpoint, train, encode commands
│   └── TUI               Dashboard, Models, Training, Inference, Checkpoint, About tabs
│
└── jepa-web         Browser demo crate
    ├── WASM API          JS-callable training and inference helpers
    ├── Demo UI           HTML, JS, and CSS assets for local browser demos
    └── Backend status    CPU-backed path today; WebGPU scaffolding remains internal

All tensor-bearing APIs are generic over B: Backend, allowing transparent execution on CPU (NdArray), GPU (WGPU), or WebAssembly backends.
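The generic-over-backend pattern looks roughly like this; the `Backend` trait below is a toy stand-in for burn's much richer `burn::tensor::backend::Backend`:

```rust
// The generic-over-backend pattern, with a toy Backend trait standing in
// for burn::tensor::backend::Backend.

trait Backend {
    fn name() -> &'static str;
    fn add(a: &[f32], b: &[f32]) -> Vec<f32>;
}

struct CpuBackend;
impl Backend for CpuBackend {
    fn name() -> &'static str { "cpu" }
    fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
}

// Model code is written once, generic over B; swapping NdArray for WGPU or
// a WASM backend is a type-parameter change, not a rewrite.
fn residual_step<B: Backend>(x: &[f32], dx: &[f32]) -> Vec<f32> {
    println!("running on backend: {}", B::name());
    B::add(x, dx)
}

fn main() {
    let y = residual_step::<CpuBackend>(&[1.0, 2.0], &[0.5, 0.5]);
    println!("{y:?}");
}
```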

ONNX Support

jepa-rs provides ONNX metadata inspection and initializer loading through jepa-compat. This allows inspecting model structure, input/output specs, and importing weight initializers from .onnx files.

Current scope: metadata inspection and weight import are production-ready. Tract-based ONNX graph execution exists (OnnxSession, OnnxEncoder) but is not yet production-grade — it is functional for prototyping and testing.

Examples

| Example | Description | Run command |
|---|---|---|
| jepa | Interactive TUI dashboard | `cargo run -p jepa` |
| jepa models | Browse pretrained model registry | `cargo run -p jepa -- models` |
| jepa train | Launch a training run | `cargo run -p jepa -- train --preset vit-base-16` |
| prepare_demo_image_folder | Generate a tiny recursive dataset for `--dataset-dir` demos | `cargo run -p jepa --example prepare_demo_image_folder` |
| train_image_folder_demo | Run the real `jepa train` image-folder path on generated images | `cargo run -p jepa --example train_image_folder_demo` |
| train_synthetic_demo | Run the real `jepa train` synthetic fallback path | `cargo run -p jepa --example train_synthetic_demo` |
| ijepa_demo | Full I-JEPA forward pass pipeline | `cargo run -p jepa-vision --example ijepa_demo` |
| ijepa_train_loop | Training loop with metrics | `cargo run -p jepa-vision --example ijepa_train_loop` |
| world_model_planning | World model with random shooting | `cargo run -p jepa-world --example world_model_planning` |
| model_registry | Browse pretrained models (library) | `cargo run -p jepa-compat --example model_registry` |

The browser demo lives in crates/jepa-web. Its exported WASM path currently uses the deterministic CPU backend for reliable local demos and tests; see the architecture and production-gap docs below before treating it as a WebGPU-ready surface.

Build & Test

# Build everything
cargo build --workspace

# Run all tests
cargo test --workspace

# Lint
cargo clippy --workspace --all-targets -- -D warnings

# Format check
cargo fmt -- --check

# Generate docs
cargo doc --workspace --no-deps --open

# Run differential parity tests
scripts/run_parity_suite.sh

# Target a single crate
cargo test -p jepa-core
cargo test -p jepa-vision
cargo test -p jepa-compat
cargo test -p jepa-web

Extended quality gates

# Code coverage (requires cargo-llvm-cov)
cargo llvm-cov --workspace --all-features --fail-under-lines 80

# Fuzz testing (requires cargo-fuzz)
(cd fuzz && cargo fuzz run masking -- -runs=1000)

# Benchmark smoke test
cargo bench --workspace --no-run

Project Docs

Project Status

As of 2026-03-16, this project is alpha. It is suitable for research, local demos, and extension work; it is not yet suitable for unqualified production deployment of every advertised surface. In particular, strict video parity is still pending, ONNX graph execution remains a prototype path, and the browser demo uses the CPU-backed WASM path today.

What works

  • Complete I-JEPA, V-JEPA, and C-JEPA architectures with strict masked-encoder paths
  • CLI with 6 commands (models, inspect, checkpoint, train, encode, tui)
  • Interactive TUI dashboard with 6 tabs (Dashboard, Models, Training, Inference, Checkpoint, About)
  • Browser demo crate with deterministic CPU-backed WASM training and inference for local experimentation
  • SafeTensors checkpoint loading with automatic key remapping
  • ONNX metadata inspection and initializer loading
  • Pretrained model registry with download URLs
  • Differential parity tests against 3 checked-in strict image fixtures
  • Comprehensive test suite (500+ tests), property-based testing, fuzz targets
  • All standard ViT configs: ViT-S/16, ViT-B/16, ViT-L/16, ViT-H/14, ViT-H/16, ViT-G/16

Known limitations

  • The generic trainer slices tokens after encoder forward; strict pre-attention masking is available via IJepa::forward_step_strict and VJepa::forward_step_strict
  • ONNX support covers metadata inspection and initializer loading only, not graph execution
  • The jepa-web crate keeps burn-wgpu scaffolding in-tree, but the exported browser demo currently runs on the burn-ndarray CPU backend only
  • Differential parity runs in CI for strict image fixtures; broader video parity is pending
  • First-time crates.io release must be published in dependency order because the workspace crates depend on each other by version

JEPA Variants: What We Implement

The JEPA family has grown across several papers. Here is exactly what jepa-rs implements and how each component maps to a specific paper and reference codebase.

I-JEPA (Image)

- **Paper:** Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (Assran et al., CVPR 2023)
- **Reference code:** facebookresearch/ijepa (archived)
- **jepa-rs struct:** `IJepa<B>` in jepa-vision (crates/jepa-vision/src/image.rs)
- **What it does:** Self-supervised image representation learning. A ViT context encoder sees only visible patches; a lightweight predictor predicts representations of masked target patches. The target encoder is an EMA copy of the context encoder.
- **Masking:** `BlockMasking` — contiguous rectangular blocks on the 2D patch grid.
- **Faithful path:** `IJepa::forward_step_strict` — filters tokens before encoder self-attention (matches the paper).
- **Approximate path:** `JepaComponents::forward_step` in jepa-train — encodes the full input, then slices (post-encoder masking; cheaper but not faithful).
- **Parity status:** 3 checked-in strict image fixtures verified in CI.
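On a 224x224 input with patch 14 the grid is 16x16, and a target block is just a contiguous rectangle of flat patch indices. A deterministic standalone sketch of that indexing (the real `BlockMasking` also samples block scale and aspect ratio, and draws several blocks per image):

```rust
// Flat patch indices of one rectangular target block on a grid_w-wide
// patch grid, as in I-JEPA block masking. Deterministic sketch only:
// the real strategy samples block scale and aspect ratio randomly.

fn block_indices(grid_w: usize, top: usize, left: usize, h: usize, w: usize) -> Vec<usize> {
    let mut idx = Vec::with_capacity(h * w);
    for row in top..top + h {
        for col in left..left + w {
            idx.push(row * grid_w + col);
        }
    }
    idx
}

fn main() {
    // 16x16 grid (224 / 14), one 3x4 target block at (row 2, col 5).
    let target = block_indices(16, 2, 5, 3, 4);
    // Context = every patch not covered by a target block.
    let context: Vec<usize> = (0..16 * 16).filter(|i| !target.contains(i)).collect();
    println!("target: {} tokens, context: {} tokens", target.len(), context.len());
}
```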

V-JEPA (Video)

- **Paper:** Revisiting Feature Prediction for Learning Visual Representations from Video (Bardes et al., 2024)
- **Reference code:** facebookresearch/jepa
- **jepa-rs struct:** `VJepa<B>` in jepa-vision (crates/jepa-vision/src/video.rs)
- **What it does:** Extends I-JEPA to video. A ViT encoder processes 3D tubelets (space + time) with 3D RoPE.
- **Masking:** `SpatiotemporalMasking` — contiguous 3D regions in the spatiotemporal grid.
- **Faithful path:** `VJepa::forward_step_strict` — pre-attention masking.
- **Parity status:** Implemented, but strict video parity is not yet proven (pending).

V-JEPA 2 features

- **Paper:** V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning (Bardes et al., 2025)
- **Reference code:** facebookresearch/vjepa2
- **jepa-rs support:** Not a separate struct. The `VJepa<B>` struct can be configured with V-JEPA 2 features.
- **What we take from V-JEPA 2:** Cosine momentum schedule for EMA — `CosineMomentumSchedule` in jepa-core (`Ema::with_cosine_schedule`). Momentum ramps from a base value (e.g. 0.996) to 1.0 over training. Also: the `MultiBlockMasking` strategy and the ViT-Giant/14 preset.
- **What we don't implement:** The full V-JEPA 2 training recipe, attentive probing, or the planning/action heads from the paper.
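The cosine momentum ramp can be written down directly; this is the standard schedule shape, though the exact formula inside jepa-core may differ in detail:

```rust
// Cosine momentum ramp from a base value to 1.0 over `total` steps, the
// standard schedule behind Ema::with_cosine_schedule. The exact formula
// in jepa-core may differ in detail.
use std::f64::consts::PI;

fn cosine_momentum(base: f64, step: usize, total: usize) -> f64 {
    let t = step as f64 / total as f64;
    // Starts at `base` (t = 0) and reaches exactly 1.0 (t = 1).
    1.0 - (1.0 - base) * (1.0 + (PI * t).cos()) / 2.0
}

fn main() {
    for &s in &[0usize, 5_000, 10_000] {
        println!("step {s:>6}: momentum = {:.6}", cosine_momentum(0.996, s, 10_000));
    }
}
```

At the end of training the momentum hits 1.0, freezing the target encoder entirely.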

Hierarchical JEPA (H-JEPA) — experimental

- **Paper:** Inspired by A Path Towards Autonomous Machine Intelligence (LeCun, 2022) — the original JEPA position paper describes hierarchical prediction as a long-term goal. No standalone H-JEPA paper exists yet.
- **jepa-rs struct:** `HierarchicalJepa<B>` in jepa-world (crates/jepa-world/src/hierarchy.rs)
- **What it does:** Stacks multiple JEPA levels at different temporal strides (e.g. stride 2, 6, 24). Each level has its own encoder and predictor. This is experimental — no reference implementation exists.

Action-Conditioned World Model — experimental

- **Paper:** Draws from both the LeCun position paper and V-JEPA 2 (planning component).
- **jepa-rs structs:** `Action<B>`, the `ActionConditionedPredictor<B>` trait, and `RandomShootingPlanner` in jepa-world (crates/jepa-world/src/action.rs, crates/jepa-world/src/planner.rs)
- **What it does:** Predicts next-state representations given the current state + action. Supports random-shooting (CEM) planning. This is experimental.
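Random shooting itself is simple enough to sketch standalone: sample many action sequences, roll each through the dynamics model, keep the cheapest. The dynamics and cost below are toy stand-ins for the learned components, and a tiny LCG replaces a proper RNG for determinism:

```rust
// Random-shooting planning in a toy 1-D latent space: sample `samples`
// action sequences, roll each through a dynamics model, keep the cheapest.
// Dynamics and cost are stand-ins for ActionConditionedPredictor + a
// learned cost function.

fn lcg(state: &mut u64) -> f64 {
    // Minimal deterministic pseudo-random source, roughly uniform in [0, 1).
    *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
    ((*state >> 33) as f64) / ((1u64 << 31) as f64)
}

fn plan(start: f64, goal: f64, horizon: usize, samples: usize) -> (Vec<f64>, f64) {
    let mut rng = 42u64;
    let mut best: (Vec<f64>, f64) = (vec![], f64::INFINITY);
    for _ in 0..samples {
        // Actions uniform in [-1, 1).
        let actions: Vec<f64> = (0..horizon).map(|_| 2.0 * lcg(&mut rng) - 1.0).collect();
        // Roll out: next_state = state + action stands in for the world model.
        let end = actions.iter().fold(start, |s, a| s + a);
        let cost = (end - goal).abs(); // distance-to-goal cost in latent space
        if cost < best.1 {
            best = (actions, cost);
        }
    }
    best
}

fn main() {
    let (actions, cost) = plan(0.0, 2.0, 4, 256);
    println!("best cost = {cost:.4}, actions = {actions:?}");
}
```

CEM refines this by refitting the sampling distribution to the elite sequences over a few iterations instead of sampling once.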

C-JEPA (Causal)

- **Paper:** Causal-JEPA: Learning World Models through Object-Level Latent Interventions (Nam et al., 2025)
- **Reference code:** galilai-group/cjepa
- **jepa-rs structs:** `ObjectMasking` in jepa-core, `SlotAttention<B>` / `SlotEncoder<B>` in jepa-vision, `CausalJepaComponents` in jepa-train, `ObjectDynamicsPredictor<B>` in jepa-world
- **What it does:** Object-centric JEPA: masks whole objects (not patches), uses a frozen encoder with slot attention, identity-anchored masked tokens, and a joint history + future MSE loss. Enables causal reasoning and efficient CEM planning in object-representation space (~98% token reduction vs patch-based models).
- **Masking:** `ObjectMasking` — randomly partitions N object slots into context/target subsets.
- **Training:** Frozen encoder (no EMA); only slot attention and the predictor are trained. `CausalJepaComponents::forward_step` in jepa-train.
- **Planning:** `ObjectDynamicsPredictor` + `RandomShootingPlanner` in jepa-world for CEM-based MPC in object space.

What about EB-JEPA?

EB-JEPA (Terver et al., 2026) is a separate lightweight Python library for energy-based JEPA. jepa-rs is not an implementation of EB-JEPA. We reference it for comparison only. The energy functions in jepa-core (L2, Cosine, SmoothL1) are standard loss formulations, not the EB-JEPA energy framework.
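Those three energies are ordinary losses and can be written in a few lines of standalone Rust (signatures here are illustrative: the real `EnergyFn` implementations operate on burn tensors):

```rust
// The three standard energies shipped in jepa-core, in plain Rust.
// Signatures are illustrative; the real EnergyFn trait works on burn tensors.

fn l2(pred: &[f32], tgt: &[f32]) -> f32 {
    pred.iter().zip(tgt).map(|(p, t)| (p - t).powi(2)).sum()
}

fn cosine_distance(pred: &[f32], tgt: &[f32]) -> f32 {
    let dot: f32 = pred.iter().zip(tgt).map(|(p, t)| p * t).sum();
    let np = pred.iter().map(|p| p * p).sum::<f32>().sqrt();
    let nt = tgt.iter().map(|t| t * t).sum::<f32>().sqrt();
    1.0 - dot / (np * nt)
}

fn smooth_l1(pred: &[f32], tgt: &[f32]) -> f32 {
    // Quadratic near zero, linear in the tails (beta = 1).
    pred.iter().zip(tgt).map(|(p, t)| {
        let d = (p - t).abs();
        if d < 1.0 { 0.5 * d * d } else { d - 0.5 }
    }).sum()
}

fn main() {
    let (p, t) = (vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]);
    println!("l2 = {}", l2(&p, &t));               // 2.0
    println!("cos = {}", cosine_distance(&p, &t)); // 1.0 (orthogonal vectors)
    println!("smooth_l1 = {}", smooth_l1(&p, &t)); // 1.0
}
```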

Quick summary

| Variant | Paper | jepa-rs struct | Status |
|---|---|---|---|
| I-JEPA | Assran et al. 2023 | `IJepa<B>` | Strict path implemented, parity verified |
| V-JEPA | Bardes et al. 2024 | `VJepa<B>` | Strict path implemented, parity pending |
| V-JEPA 2 | Bardes et al. 2025 | `VJepa<B>` + cosine EMA schedule | Select features only |
| H-JEPA | LeCun 2022 (position paper) | `HierarchicalJepa<B>` | Experimental, no reference impl |
| C-JEPA | Nam et al. 2025 | `ObjectMasking`, `SlotAttention`, `CausalJepaComponents`, `ObjectDynamicsPredictor` | Core support implemented |
| World model | LeCun 2022 + V-JEPA 2 | `ActionConditionedPredictor`, `RandomShootingPlanner` | Experimental |
| EB-JEPA | Terver et al. 2026 | Not implemented | Referenced for comparison only |

References

Papers

| Paper | Focus |
|---|---|
| A Path Towards Autonomous Machine Intelligence | JEPA position paper — hierarchical world models (LeCun, 2022) |
| I-JEPA | Self-supervised image learning with masked prediction in latent space (Assran et al., CVPR 2023) |
| V-JEPA | Extension to video with spatiotemporal masking (Bardes et al., 2024) |
| V-JEPA 2 | Video understanding, prediction, and planning (Bardes et al., 2025) |
| C-JEPA | Object-centric world models via latent interventions (Nam et al., 2025) |
| Slot Attention | Object-centric learning with slot attention (Locatello et al., NeurIPS 2020) |
| EB-JEPA | Lightweight energy-based JEPA library — referenced for comparison (Terver et al., 2026) |

Official reference implementations

| Repo | Models | Relationship to jepa-rs |
|---|---|---|
| facebookresearch/ijepa | I-JEPA (archived) | Primary reference for `IJepa<B>` and key remapping |
| facebookresearch/jepa | V-JEPA | Primary reference for `VJepa<B>` |
| facebookresearch/vjepa2 | V-JEPA 2 | Reference for cosine EMA schedule, ViT-G config |
| galilai-group/cjepa | C-JEPA | Primary reference for `ObjectMasking`, `SlotAttention`, `CausalJepaComponents` |
| facebookresearch/eb_jepa | EB-JEPA tutorial | Not implemented — comparison only |

Contributing

See CONTRIBUTING.md for guidelines.

License

MIT License. See LICENSE for details.


Built with burn and tract
