Joint Embedding Predictive Architecture in Rust
An alpha-stage Rust implementation of JEPA (Joint Embedding Predictive Architecture) — the self-supervised learning framework from Yann LeCun and Meta AI for learning world models that predict in representation space rather than pixel space.
jepa-rs provides modular, backend-agnostic building blocks for I-JEPA (images), V-JEPA (video), C-JEPA (causal object-centric), and hierarchical world models, built on top of the burn deep learning framework. It includes a CLI, an interactive TUI dashboard, a browser demo crate for local experimentation, safetensors checkpoint loading, ONNX metadata inspection, and a pretrained model registry for Facebook Research models.
┌──────────────┐
│ Context │──── Encoder ────┐
│ (visible) │ │
Image/Video ─────┤ │ ┌───────▼───────┐
│ Target │ │ Predictor │──── predicted repr
│ (masked) │──┐ └───────────────┘ │
└──────────────┘ │ │
│ ┌───────────────┐ │
└──────│ Target Encoder│── target repr
EMA │ (frozen) │ │
└───────────────┘ │
│
┌───────────────┐ │
│ Energy Loss │◄─────────┘
└───────────────┘
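The target encoder in the diagram is not trained by gradient descent; its weights track the context encoder by exponential moving average. A minimal dependency-free sketch of that update, using plain f32 slices rather than jepa-core's tensor-based `Ema` type:

```rust
// Minimal sketch of the EMA target-encoder update: each target weight
// moves a small step toward the corresponding context-encoder weight.
// (jepa-core's real `Ema` operates on burn tensors; this only
// illustrates the arithmetic.)
fn ema_update(target: &mut [f32], context: &[f32], momentum: f32) {
    for (t, c) in target.iter_mut().zip(context) {
        *t = momentum * *t + (1.0 - momentum) * c;
    }
}

fn main() {
    let mut target = vec![0.0_f32; 4];
    let context = vec![1.0_f32; 4];
    // With momentum 0.996, each weight moves 0.4% of the way per step.
    ema_update(&mut target, &context, 0.996);
    println!("{:?}", target);
}
```

Because the momentum is close to 1.0, the target encoder changes slowly, which is what keeps the prediction targets stable during training.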
| | jepa-rs | Python (PyTorch) |
|---|---|---|
| Runtime | Native binary, no Python/CUDA dependency | Requires Python + PyTorch + CUDA |
| Inference | Safetensors checkpoint loading, ONNX metadata | PyTorch runtime |
| Memory | Rust ownership, no GC pauses | Python GC + PyTorch allocator |
| Backend | Any burn backend (CPU, GPU, WebGPU, WASM) | CUDA-centric |
| Type safety | Compile-time tensor shape checks | Runtime shape errors |
| Deployment | Single static binary | Docker + Python environment |
jepa-rs supports loading official Facebook Research pretrained JEPA models:
| Model | Architecture | Params | Resolution | Dataset | Weights |
|---|---|---|---|---|---|
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-1K | Download / HuggingFace |
| I-JEPA ViT-H/16-448 | ViT-Huge, patch 16 | 632M | 448x448 | ImageNet-1K | Download / HuggingFace |
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-22K | Download |
| I-JEPA ViT-G/16 | ViT-Giant, patch 16 | 1.0B | 224x224 | ImageNet-22K | Download |
| V-JEPA ViT-L/16 | ViT-Large, patch 16 | 304M | 224x224 | VideoMix2M | Download |
| V-JEPA ViT-H/16 | ViT-Huge, patch 16 | 632M | 224x224 | VideoMix2M | Download |
# Cargo.toml
[dependencies]
jepa-core = "0.1.0"
jepa-vision = "0.1.0"
jepa-compat = "0.1.0" # For ONNX + checkpoint loading

The jepa binary provides a unified CLI for the workspace:
# Install the CLI from crates.io
cargo install jepa
# Or install from the local workspace checkout
cargo install --path crates/jepa
# Launch the interactive TUI dashboard
jepa
# List pretrained models in the registry
jepa models
# Inspect a safetensors checkpoint
jepa inspect model.safetensors
# Analyze checkpoint with key remapping
jepa checkpoint model.safetensors --keymap ijepa --verbose
# Launch a training run
jepa train --preset vit-base-16 --steps 10 --batch-size 1 --lr 1e-3
# Train from a normal image directory tree with deterministic resize/crop/normalize
jepa train --preset vit-base-16 --steps 100 --batch-size 4 \
--dataset-dir ./images/train --resize 256 --crop-size 224 --shuffle
# Train from a safetensors image tensor dataset [N, C, H, W]
jepa train --preset vit-base-16 --steps 100 --batch-size 1 \
--dataset train.safetensors --dataset-key images
# Encode inputs through a safetensors checkpoint
jepa encode --model model.safetensors --preset vit-base-16
# Or through an ONNX model
jepa encode --model model.onnx --height 224 --width 224

The CLI train command now runs real strict masked-image optimization with
AdamW and EMA. It chooses one input source per run:

- --dataset-dir <PATH> for a recursive image-folder dataset (jpg, jpeg, png, webp) with decode, RGB conversion, shorter-side resize, center crop, CHW tensor conversion, and normalization
- --dataset <FILE> --dataset-key <KEY> for a safetensors image tensor shaped [N, C, H, W]
- no dataset flags for the synthetic random-tensor fallback
Image-folder preprocessing defaults to the preset image size for --crop-size
and the ImageNet RGB normalization statistics when --mean and --std are
omitted. Dataset loading is currently single-threaded.
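For reference, the ImageNet RGB statistics used as the fallback defaults give per-channel normalization of the form `(value - mean) / std`. An illustrative sketch (not the crate's actual preprocessing code):

```rust
// ImageNet RGB normalization statistics (the CLI's fallback when
// --mean/--std are omitted): out = (value - mean) / std per channel.
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

// Normalize one RGB pixel already scaled to [0, 1].
fn normalize_pixel(rgb: [f32; 3]) -> [f32; 3] {
    let mut out = [0.0_f32; 3];
    for c in 0..3 {
        out[c] = (rgb[c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
    out
}

fn main() {
    // A pixel equal to the channel means lands at zero in every channel.
    println!("{:?}", normalize_pixel([0.485, 0.456, 0.406]));
}
```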
jepa encode runs inference with real encoder weights for .safetensors and .onnx
inputs; other extensions still fall back to the preset demo path.
The jepa crate now ships runnable examples under
crates/jepa/examples/ that exercise the real training command instead of
mocking the CLI path:
# Create a tiny recursive image-folder dataset under target/example-data/jepa/
cargo run -p jepa --example prepare_demo_image_folder
# Train for 2 steps on that generated image-folder dataset
cargo run -p jepa --example train_image_folder_demo
# Train for 2 steps with the synthetic fallback path
cargo run -p jepa --example train_synthetic_demo

The image-folder example deliberately uses a very small generated dataset (6 PNG files across nested subdirectories). That is enough for a meaningful smoke demo of recursive dataset discovery, decode, resize, crop, normalize, batching, masking, optimizer updates, and EMA without checking a large image corpus into git. It is not large enough to demonstrate real representation learning quality; it is an execution demo, not a benchmark dataset.
The TUI now incorporates these demos in the Training tab as a guided demo
runner. Launch jepa, switch to tab 3, choose a demo with j/k, and press
Enter to run it. The panel streams real run logs, step metrics, loss/energy
charts, and a short interpretation of what happened.
The TUI Inference tab (tab 4) adds a separate guided walkthrough for encoder
inference. It runs deterministic demo image patterns through a preset ViT,
streams phase changes, per-sample latency and embedding statistics, and explains
what the representation telemetry means. The walkthrough is intentionally a
pipeline demo rather than a pretrained semantic benchmark.
If you want to run the CLI directly after generating the demo dataset:
cargo run -p jepa -- train --preset vit-small-16 --steps 2 --batch-size 2 \
--dataset-dir target/example-data/jepa/demo-image-folder \
--resize 256 --crop-size 224 --shuffle --dataset-limit 6

use jepa_compat::safetensors::load_checkpoint;
use jepa_compat::keymap::ijepa_vit_keymap;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let mappings = ijepa_vit_keymap();
let checkpoint = load_checkpoint("model.safetensors", &mappings)?;
println!("Loaded {} tensors", checkpoint.len());
for key in checkpoint.keys() {
println!(" {}: {:?}", key, checkpoint.get(key).unwrap().shape);
}
Ok(())
}

use burn::prelude::*;
use burn_ndarray::NdArray;
use rand::SeedableRng; // brings ChaCha8Rng::seed_from_u64 into scope
use jepa_core::masking::{BlockMasking, MaskingStrategy};
use jepa_core::types::InputShape;
use jepa_vision::image::IJepaConfig;
use jepa_vision::vit::VitConfig;
type B = NdArray<f32>;
fn main() {
let device = burn_ndarray::NdArrayDevice::Cpu;
// Configure I-JEPA with ViT-Huge/14 (matches Facebook pretrained)
let config = IJepaConfig {
encoder: VitConfig::vit_huge_patch14(),
predictor: jepa_vision::image::TransformerPredictorConfig {
encoder_embed_dim: 1280,
predictor_embed_dim: 384,
num_layers: 12,
num_heads: 12,
max_target_len: 256,
},
};
let model = config.init::<B>(&device);
// Generate masks (I-JEPA block masking)
let shape = InputShape::Image { height: 16, width: 16 }; // 224/14 = 16
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let masking = BlockMasking {
num_targets: 4,
target_scale: (0.15, 0.2),
target_aspect_ratio: (0.75, 1.5),
};
let mask = masking.generate_mask(&shape, &mut rng);
println!("Context tokens: {}, Target tokens: {}",
mask.context_indices.len(), mask.target_indices.len());
}

use jepa_compat::registry::{list_models, find_model};
fn main() {
for model in list_models() {
println!("{}: {} ({}, {})",
model.name,
model.param_count_human(),
model.architecture,
model.pretrained_on);
}
// Search for a specific model
if let Some(m) = find_model("vit-h/14") {
println!("\nFound: {} with {} patches",
m.name, m.num_patches());
}
}

jepa-rs/
├── jepa-core Core traits, tensor wrappers, masking, energy, EMA
│ ├── Encoder Trait for context/target encoders
│ ├── Predictor Trait for latent predictors
│ ├── EnergyFn L2, Cosine, SmoothL1 energy functions
│ ├── MaskingStrategy Block, MultiBlock, Spatiotemporal, Object masking
│ ├── CollapseReg VICReg, BarlowTwins collapse prevention
│ └── EMA Exponential moving average with cosine schedule
│
├── jepa-vision Vision transformers and JEPA models
│ ├── VitEncoder ViT-S/B/L/H/G with 2D RoPE
│ ├── IJepa I-JEPA pipeline (image)
│ ├── VJepa V-JEPA pipeline (video, 3D tubelets)
│ ├── SlotAttention Slot attention encoder for object-centric representations
│ └── Predictor Transformer-based cross-attention predictor
│
├── jepa-world World models and planning
│ ├── ActionPredictor Action-conditioned latent prediction
│ ├── ObjectDynamics Transformer dynamics predictor for object slots
│ ├── Planner Random shooting planner with cost functions
│ ├── HierarchicalJepa Multi-level H-JEPA
│ └── ShortTermMemory Sliding-window memory for temporal context
│
├── jepa-train Training orchestration
│ ├── TrainConfig Learning rate schedules, EMA config
│ ├── JepaComponents Generic forward step orchestration
│ ├── CausalJepa C-JEPA training loop (frozen encoder, object masking)
│ └── CheckpointMeta Save/resume metadata
│
├── jepa-compat Model compatibility and interop
│ ├── ModelRegistry Pretrained model catalog (Facebook Research)
│ ├── SafeTensors Load .safetensors checkpoints
│ ├── KeyMap PyTorch → burn key remapping
│ └── OnnxModelInfo ONNX metadata inspection and initializer loading
│
├── jepa CLI and interactive TUI dashboard
│ ├── CLI models, inspect, checkpoint, train, encode commands
│ └── TUI Dashboard, Models, Training, Inference, Checkpoint, About tabs
│
└── jepa-web Browser demo crate
├── WASM API JS-callable training and inference helpers
├── Demo UI HTML, JS, and CSS assets for local browser demos
└── Backend status CPU-backed path today; WebGPU scaffolding remains internal
All tensor-bearing APIs are generic over B: Backend, allowing transparent execution on CPU (NdArray), GPU (WGPU), or WebAssembly backends.
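The pattern can be illustrated without burn at all: write the algorithm once against a trait, and let each backend supply the kernels. This is a simplified analogue with a hypothetical two-method trait — burn's real Backend trait is far richer:

```rust
// Simplified analogue of the backend-generic pattern: the algorithm is
// written once against a trait, and each backend supplies the kernels.
// (Hypothetical trait for illustration; not burn's `Backend`.)
trait Backend {
    /// Multiply two n x n row-major matrices.
    fn matmul(a: &[f32], b: &[f32], n: usize) -> Vec<f32>;
}

struct Cpu;

impl Backend for Cpu {
    fn matmul(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
        let mut out = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    out[i * n + j] += a[i * n + k] * b[k * n + j];
                }
            }
        }
        out
    }
}

// Generic over the backend: a GPU or WASM implementation of the trait
// would reuse this exact code path.
fn project<B: Backend>(x: &[f32], w: &[f32], n: usize) -> Vec<f32> {
    B::matmul(x, w, n)
}

fn main() {
    let identity = [1.0, 0.0, 0.0, 1.0];
    let weights = [2.0, 3.0, 4.0, 5.0];
    println!("{:?}", project::<Cpu>(&identity, &weights, 2));
}
```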
jepa-rs provides ONNX metadata inspection and initializer loading through jepa-compat. This allows inspecting model structure, input/output specs, and importing weight initializers from .onnx files.
Current scope: metadata inspection and weight import are production-ready. Tract-based ONNX graph execution exists (OnnxSession, OnnxEncoder) but is not yet production-grade — it is functional for prototyping and testing.
| Example | Description | Run command |
|---|---|---|
| jepa | Interactive TUI dashboard | cargo run -p jepa |
| jepa models | Browse pretrained model registry | cargo run -p jepa -- models |
| jepa train | Launch a training run | cargo run -p jepa -- train --preset vit-base-16 |
| prepare_demo_image_folder | Generate a tiny recursive dataset for --dataset-dir demos | cargo run -p jepa --example prepare_demo_image_folder |
| train_image_folder_demo | Run the real jepa train image-folder path on generated images | cargo run -p jepa --example train_image_folder_demo |
| train_synthetic_demo | Run the real jepa train synthetic fallback path | cargo run -p jepa --example train_synthetic_demo |
| ijepa_demo | Full I-JEPA forward pass pipeline | cargo run -p jepa-vision --example ijepa_demo |
| ijepa_train_loop | Training loop with metrics | cargo run -p jepa-vision --example ijepa_train_loop |
| world_model_planning | World model with random shooting | cargo run -p jepa-world --example world_model_planning |
| model_registry | Browse pretrained models (library) | cargo run -p jepa-compat --example model_registry |
The browser demo lives in crates/jepa-web. Its exported WASM path currently
uses the deterministic CPU backend for reliable local demos and tests; see the
architecture and production-gap docs below before treating it as a WebGPU-ready
surface.
# Build everything
cargo build --workspace
# Run all tests
cargo test --workspace
# Lint
cargo clippy --workspace --all-targets -- -D warnings
# Format check
cargo fmt -- --check
# Generate docs
cargo doc --workspace --no-deps --open
# Run differential parity tests
scripts/run_parity_suite.sh
# Target a single crate
cargo test -p jepa-core
cargo test -p jepa-vision
cargo test -p jepa-compat
cargo test -p jepa-web

# Code coverage (requires cargo-llvm-cov)
cargo llvm-cov --workspace --all-features --fail-under-lines 80
# Fuzz testing (requires cargo-fuzz)
(cd fuzz && cargo fuzz run masking -- -runs=1000)
# Benchmark smoke test
cargo bench --workspace --no-run

- docs/ARCHITECTURE.md - crate boundaries, data flow, and invariants
- docs/QUALITY_GATES.md - exact local and release verification commands
- docs/RELEASE.md - current release checklist and publish constraints
- docs/ROADMAP.md - next milestones with explicit exit criteria
- docs/PRODUCTION_GAPS.md - honest register of remaining blockers and risks
As of 2026-03-16, this project is alpha. It is suitable for research, local demos, and extension work; it is not yet suitable for unqualified production deployment of every advertised surface. In particular, strict video parity is still pending, ONNX graph execution remains a prototype path, and the browser demo uses the CPU-backed WASM path today.
- Complete I-JEPA, V-JEPA, and C-JEPA architectures with strict masked-encoder paths
- CLI with 6 commands (models, inspect, checkpoint, train, encode, tui)
- Interactive TUI dashboard with 6 tabs (Dashboard, Models, Training, Inference, Checkpoint, About)
- Browser demo crate with deterministic CPU-backed WASM training and inference for local experimentation
- SafeTensors checkpoint loading with automatic key remapping
- ONNX metadata inspection and initializer loading
- Pretrained model registry with download URLs
- Differential parity tests against 3 checked-in strict image fixtures
- Comprehensive test suite (500+ tests), property-based testing, fuzz targets
- All standard ViT configs: ViT-S/16, ViT-B/16, ViT-L/16, ViT-H/14, ViT-H/16, ViT-G/16
- The generic trainer slices tokens after the encoder forward pass; strict pre-attention masking is available via IJepa::forward_step_strict and VJepa::forward_step_strict
- ONNX support covers metadata inspection and initializer loading only, not graph execution
- The jepa-web crate keeps burn-wgpu scaffolding in-tree, but the exported browser demo currently runs on the burn-ndarray CPU backend only
- Differential parity runs in CI for strict image fixtures; broader video parity is pending
- First-time crates.io release must be published in dependency order because the workspace crates depend on each other by version
The JEPA family has grown across several papers. Here is exactly what jepa-rs implements and how each component maps to a specific paper and reference codebase.
| Paper | Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (Assran et al., CVPR 2023) |
| Reference code | facebookresearch/ijepa (archived) |
| jepa-rs struct | IJepa<B> in jepa-vision (crates/jepa-vision/src/image.rs) |
| What it does | Self-supervised image representation learning. A ViT context-encoder sees only visible patches; a lightweight predictor predicts representations of masked target patches. The target-encoder is an EMA copy of the context-encoder. |
| Masking | BlockMasking — contiguous rectangular blocks on the 2D patch grid. |
| Faithful path | IJepa::forward_step_strict — filters tokens before encoder self-attention (matches the paper). |
| Approximate path | JepaComponents::forward_step in jepa-train — encodes full input then slices (post-encoder masking; cheaper but not faithful). |
| Parity status | 3 checked-in strict image fixtures verified in CI. |
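For intuition about block masking, here is one way a scale and aspect ratio can be turned into block dimensions on the 16x16 patch grid of ViT-H/14 at 224x224. This is a hedged sketch of the arithmetic only; jepa-core's `BlockMasking` may sample block shapes differently:

```rust
// Hedged sketch: derive a target block's height/width on the patch grid
// from a scale (fraction of patches covered) and an aspect ratio
// (height / width). jepa-core's `BlockMasking` may sample differently.
fn block_dims(grid_h: usize, grid_w: usize, scale: f32, aspect: f32) -> (usize, usize) {
    let area = scale * (grid_h * grid_w) as f32; // patches in the block
    let h = (area * aspect).sqrt().round() as usize;
    let w = (area / aspect).sqrt().round() as usize;
    (h.clamp(1, grid_h), w.clamp(1, grid_w))
}

fn main() {
    // 15% of a 16x16 grid at unit aspect ratio -> roughly a 6x6 block.
    println!("{:?}", block_dims(16, 16, 0.15, 1.0));
}
```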
| Paper | Revisiting Feature Prediction for Learning Visual Representations from Video (Bardes et al., 2024) |
| Reference code | facebookresearch/jepa |
| jepa-rs struct | VJepa<B> in jepa-vision (crates/jepa-vision/src/video.rs) |
| What it does | Extends I-JEPA to video. A ViT encoder processes 3D tubelets (space + time) with 3D RoPE. |
| Masking | SpatiotemporalMasking — contiguous 3D regions in the spatiotemporal grid. |
| Faithful path | VJepa::forward_step_strict — pre-attention masking. |
| Parity status | Implemented but strict video parity not yet proven (pending). |
| Paper | V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning (Bardes et al., 2025) |
| Reference code | facebookresearch/vjepa2 |
| jepa-rs support | Not a separate struct. The VJepa<B> struct can be configured with V-JEPA 2 features. |
| What we take from V-JEPA 2 | Cosine momentum schedule for EMA — CosineMomentumSchedule in jepa-core (Ema::with_cosine_schedule). Momentum ramps from base (e.g. 0.996) to 1.0 over training. Also: MultiBlockMasking strategy, ViT-Giant/14 preset. |
| What we don't implement | The full V-JEPA 2 training recipe, attentive probing, or the planning/action heads from the paper. |
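A sketch of such a cosine ramp from the base momentum to 1.0 over training (the exact parameterization inside jepa-core's `CosineMomentumSchedule` may differ):

```rust
use std::f32::consts::PI;

// Cosine ramp of the EMA momentum from `base` at step 0 to 1.0 at the
// final step. (Sketch only; jepa-core's `CosineMomentumSchedule` may
// use a slightly different parameterization.)
fn cosine_momentum(step: usize, total_steps: usize, base: f32) -> f32 {
    let t = step as f32 / total_steps as f32;
    1.0 - (1.0 - base) * 0.5 * (1.0 + (PI * t).cos())
}

fn main() {
    for step in [0, 5_000, 10_000] {
        println!("step {step}: momentum = {}", cosine_momentum(step, 10_000, 0.996));
    }
}
```

The ramp means the target encoder updates fastest early in training and freezes toward the end, which stabilizes the prediction targets as representations mature.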
| Paper | Inspired by A Path Towards Autonomous Machine Intelligence (LeCun, 2022) — the original JEPA position paper describes hierarchical prediction as a long-term goal. No standalone H-JEPA paper exists yet. |
| jepa-rs struct | HierarchicalJepa<B> in jepa-world (crates/jepa-world/src/hierarchy.rs) |
| What it does | Stacks multiple JEPA levels at different temporal strides (e.g. stride 2, 6, 24). Each level has its own encoder and predictor. This is experimental — no reference implementation exists. |
| Paper | Draws from both the LeCun position paper and V-JEPA 2 (planning component). |
| jepa-rs structs | Action<B>, ActionConditionedPredictor<B> trait, RandomShootingPlanner in jepa-world (crates/jepa-world/src/action.rs, crates/jepa-world/src/planner.rs) |
| What it does | Predicts next-state representations given current state + action. Supports random-shooting (CEM) planning. This is experimental. |
| Paper | Causal-JEPA: Learning World Models through Object-Level Latent Interventions (Nam et al., 2025) |
| Reference code | galilai-group/cjepa |
| jepa-rs structs | ObjectMasking in jepa-core, SlotAttention<B> / SlotEncoder<B> in jepa-vision, CausalJepaComponents in jepa-train, ObjectDynamicsPredictor<B> in jepa-world |
| What it does | Object-centric JEPA: masks whole objects (not patches), uses a frozen encoder with slot attention, identity-anchored masked tokens, and joint history + future MSE loss. Enables causal reasoning and efficient CEM planning in object-representation space (~98% token reduction vs patch-based models). |
| Masking | ObjectMasking — randomly partitions N object slots into context/target subsets. |
| Training | Frozen encoder (no EMA), only slot attention and predictor are trained. CausalJepaComponents::forward_step in jepa-train. |
| Planning | ObjectDynamicsPredictor + RandomShootingPlanner in jepa-world for CEM-based MPC in object space. |
EB-JEPA (Terver et al., 2026) is a separate lightweight Python library for energy-based JEPA. jepa-rs is not an implementation of EB-JEPA. We reference it for comparison only. The energy functions in jepa-core (L2, Cosine, SmoothL1) are standard loss formulations, not the EB-JEPA energy framework.
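For concreteness, the L2 variant is just the mean squared distance between predicted and target representations. A plain-slice illustration of the formula, not jepa-core's tensor-based `EnergyFn` API:

```rust
// Standard L2 energy: mean squared distance between the predicted and
// target representations. (Plain-slice illustration of the formula;
// jepa-core's `EnergyFn` operates on burn tensors.)
fn l2_energy(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    pred.iter()
        .zip(target)
        .map(|(p, t)| (p - t) * (p - t))
        .sum::<f32>()
        / pred.len() as f32
}

fn main() {
    // (0^2 + 2^2) / 2 = 2
    println!("{}", l2_energy(&[1.0, 2.0], &[1.0, 4.0]));
}
```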
| Variant | Paper | jepa-rs struct | Status |
|---|---|---|---|
| I-JEPA | Assran et al. 2023 | IJepa<B> |
Strict path implemented, parity verified |
| V-JEPA | Bardes et al. 2024 | VJepa<B> |
Strict path implemented, parity pending |
| V-JEPA 2 | Bardes et al. 2025 | VJepa<B> + cosine EMA schedule |
Select features only |
| H-JEPA | LeCun 2022 (position paper) | HierarchicalJepa<B> |
Experimental, no reference impl |
| C-JEPA | Nam et al. 2025 | ObjectMasking, SlotAttention, CausalJepaComponents, ObjectDynamicsPredictor |
Core support implemented |
| World model | LeCun 2022 + V-JEPA 2 | ActionConditionedPredictor, RandomShootingPlanner |
Experimental |
| EB-JEPA | Terver et al. 2026 | Not implemented | Referenced for comparison only |
| Paper | Focus |
|---|---|
| A Path Towards Autonomous Machine Intelligence | JEPA position paper — hierarchical world models (LeCun, 2022) |
| I-JEPA | Self-supervised image learning with masked prediction in latent space (Assran et al., CVPR 2023) |
| V-JEPA | Extension to video with spatiotemporal masking (Bardes et al., 2024) |
| V-JEPA 2 | Video understanding, prediction, and planning (Bardes et al., 2025) |
| C-JEPA | Object-centric world models via latent interventions (Nam et al., 2025) |
| Slot Attention | Object-centric learning with slot attention (Locatello et al., NeurIPS 2020) |
| EB-JEPA | Lightweight energy-based JEPA library — referenced for comparison (Terver et al., 2026) |
| Repo | Models | Relationship to jepa-rs |
|---|---|---|
facebookresearch/ijepa |
I-JEPA (archived) | Primary reference for IJepa<B> and key remapping |
facebookresearch/jepa |
V-JEPA | Primary reference for VJepa<B> |
facebookresearch/vjepa2 |
V-JEPA 2 | Reference for cosine EMA schedule, ViT-G config |
galilai-group/cjepa |
C-JEPA | Primary reference for ObjectMasking, SlotAttention, CausalJepaComponents |
facebookresearch/eb_jepa |
EB-JEPA tutorial | Not implemented — comparison only |
See CONTRIBUTING.md for guidelines.
MIT License. See LICENSE for details.