Merged
31 commits
9da01fa
full training infra implementation
alexzms Mar 8, 2026
c20b025
remove sampler_kind scheduler kind
alexzms Mar 8, 2026
36af66f
scrips and yamls
alexzms Mar 8, 2026
27fd7c1
fix ema with fsdp
alexzms Mar 8, 2026
cb3f85c
fix wandb tracker
alexzms Mar 8, 2026
7427510
no redundant predict x0 in model specific impl
alexzms Mar 8, 2026
1dceae4
fix validation dmd timestep missing
alexzms Mar 8, 2026
7353a4f
causal wan init impl
alexzms Mar 8, 2026
05b40f8
move dev file position
alexzms Mar 8, 2026
57a0ee9
Merge branch 'main' of https://github.com/hao-ai-lab/FastVideo into t…
jzhang38 Mar 8, 2026
4a56625
mv design to docs
jzhang38 Mar 8, 2026
9abae81
update ema
jzhang38 Mar 8, 2026
6ebcbdc
fix generator and reproducibility
jzhang38 Mar 8, 2026
86dcdad
testing self forcing
alexzms Mar 8, 2026
3898fbc
training infra doc init impl
alexzms Mar 8, 2026
36bff05
generic cli override of config. remove seperate resume
alexzms Mar 8, 2026
aff3174
ode init conversion, deterministic unset
alexzms Mar 8, 2026
0e6d8e9
remove deprecated rollout_mode param
alexzms Mar 9, 2026
cfef92b
update agents
jzhang38 Mar 9, 2026
498a712
validation remove deprecated dmd_steps
alexzms Mar 9, 2026
ad4622a
fix ckpt resuming
alexzms Mar 9, 2026
0a0dc8f
Merge branch 'train-clean-refactor' of https://github.com/FoundationR…
jzhang38 Mar 9, 2026
5409a6a
Merge branch 'hao-ai-lab:main' into train-clean-refactor
alexzms Mar 9, 2026
4512ad5
revert unused randomstate wrapper
alexzms Mar 9, 2026
c32f616
precommit
alexzms Mar 9, 2026
f222068
slurm run script
alexzms Mar 9, 2026
2449cd7
minor
alexzms Mar 9, 2026
17f1e38
+ resume from checkpoint latest
jzhang38 Mar 9, 2026
b3ef4e0
Merge branch 'train-clean-refactor' of https://github.com/FoundationR…
jzhang38 Mar 9, 2026
680bc8e
precommit
jzhang38 Mar 9, 2026
0674a89
Merge remote-tracking branch 'origin/train-clean-refactor' into train…
alexzms Mar 9, 2026
26 changes: 24 additions & 2 deletions .agents/memory/codebase-map/README.md
@@ -1,6 +1,6 @@
# FastVideo-WorldModel — Codebase Map

-High-level structural index for agent orientation. Updated 2026-03-02.
High-level structural index for agent orientation. Updated 2026-03-08.

## Repository Layout

@@ -22,7 +22,18 @@ FastVideo-WorldModel/
│ ├── pipelines/ # End-to-end pipelines
│ │ ├── basic/ # Per-model pipelines (wan/, ltx2/, ...)
│ │ └── stages/ # Reusable pipeline stages
-│ ├── training/ # Training infrastructure
│ ├── train/ # Refactored training framework (YAML-driven, preferred)
│ │ ├── trainer.py # Main training loop coordinator
│ │ ├── entrypoint/ # Training entrypoint (train.py) + checkpoint conversion
│ │ ├── methods/ # Training algorithms (FineTune, DFSFT, DMD2, SelfForcing)
│ │ │ ├── base.py # TrainingMethod ABC
│ │ │ ├── fine_tuning/ # FineTuneMethod, DiffusionForcingSFTMethod
│ │ │ └── distribution_matching/ # DMD2Method, SelfForcingMethod
│ │ ├── models/ # Per-role model wrappers (ModelBase, CausalModelBase)
│ │ │ └── wan/ # WanModel, WanCausalModel
│ │ ├── callbacks/ # Composable hooks (grad_clip, ema, validation)
│ │ └── utils/ # Config, builder, checkpoint, optimizer, tracking
│ ├── training/ # Legacy training infrastructure (being phased out)
│ │ ├── trackers.py # W&B tracker (BaseTracker → WandbTracker)
│ │ ├── training_utils.py # Checkpointing, grad clipping, state dicts
│ │ ├── training_pipeline.py # Base training pipeline
@@ -65,6 +76,17 @@ FastVideo-WorldModel/

## Key Training Entrypoints

### New framework (`fastvideo/train/`) — preferred

| Method | Config Example | Launch Pattern |
|--------|---------------|----------------|
| FineTune (Wan) | `examples/train/finetune_wan2.1_t2v_1.3B_vsa_*.yaml` | `torchrun -m fastvideo.train.entrypoint.train --config <yaml>` |
| DFSFT (Wan causal) | `examples/train/dfsft_wan_causal_t2v_1.3B.yaml` | `torchrun -m fastvideo.train.entrypoint.train --config <yaml>` |
| DMD2 distillation | `examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml` | `torchrun -m fastvideo.train.entrypoint.train --config <yaml>` |
| Self-Forcing | `examples/train/self_forcing_wan_causal_t2v_1.3B.yaml` | `torchrun -m fastvideo.train.entrypoint.train --config <yaml>` |

### Legacy pipelines (`fastvideo/training/`) — being phased out

| Pipeline | Entrypoint | Launch Pattern |
|----------|-----------|----------------|
| Wan T2V finetune | `fastvideo/training/wan_training_pipeline.py` | `torchrun --nproc_per_node N` |
2 changes: 1 addition & 1 deletion .agents/memory/evaluation-registry/README.md
@@ -280,7 +280,7 @@ A comprehensive benchmark examining **four critical capabilities**:
**Priority**: **Highest** — this is the most important evaluation signal
**Trust**: Highest — but expensive

-#### What It Measures
### What It Measures
Human evaluators compare generated videos and rate them on dimensions like:
- Overall quality and realism
- Temporal coherence and smoothness
7 changes: 4 additions & 3 deletions .agents/onboarding/README.md
@@ -22,9 +22,10 @@ Read these files to build your context:
|----------|------|----------------|
| 1 | `AGENTS.md` | Coding guidelines, build/test commands, PR conventions |
| 2 | `docs/design/overview.md` | Architecture: models, pipelines, configs, registry |
-| 3 | `docs/training/overview.md` | Training data flow and preprocessing |
-| 4 | `docs/training/finetune.md` | Training arguments, parallelism, LoRA, validation |
-| 5 | `docs/contributing/coding_agents.md` | How to add model pipelines with agent assistance |
| 3 | `fastvideo/train/` | Refactored training framework (YAML-driven, modular methods/models/callbacks) |
| 4 | `docs/training/overview.md` | Training data flow and preprocessing |
| 5 | `docs/training/finetune.md` | Training arguments, parallelism, LoRA, validation |
| 6 | `docs/contributing/coding_agents.md` | How to add model pipelines with agent assistance |

## Step 2: Discover Available Resources

260 changes: 196 additions & 64 deletions .agents/onboarding/worldmodel-training/README.md
@@ -17,8 +17,25 @@ auto-regressive streaming generation.
- Full finetuning and LoRA on Wan / LTX-2 / MatrixGame models
- DMD-based distillation (few-step generation)
- Self-Forcing distillation (causal streaming)
-- Consistency finetuning with causal ODE initialization
-- Action injection modules for interactive control
- Diffusion-Forcing SFT (DFSFT) for causal models
- VSA (Video Sparse Attention) for efficient training

---

## Training Code: Two Generations

### New modular framework: `fastvideo/train/` (preferred)

The refactored training code uses a **YAML-only, config-driven** architecture
with composable methods, per-role model wrappers, and a callback system. All
new training work should use this framework.

### Legacy pipelines: `fastvideo/training/` (deprecated)

The old monolithic pipeline classes (`WanTrainingPipeline`,
`DistillationPipeline`, etc.) still exist but are being phased out. The new
framework imports select utilities from `fastvideo/training/` for backward
compatibility (EMA, gradient clipping, checkpoint wrappers).

---

@@ -29,74 +46,202 @@ Read these **in order** before touching any training code:
| # | File | What You Learn |
|---|------|----------------|
| 1 | `docs/training/overview.md` | Training data flow: raw video → text embeddings + video latents → training |
-| 2 | `docs/training/finetune.md` | All training arguments, parallelism (SP/TP), LoRA, validation settings |
| 2 | `docs/training/finetune.md` | Training arguments, parallelism (SP/TP), LoRA, validation settings |
| 3 | `docs/training/data_preprocess.md` | How to preprocess datasets into the expected format |
| 4 | `docs/design/overview.md` | Architecture: models, pipelines, configs, registry |

---

-## Training Pipelines
## New Training Framework (`fastvideo/train/`)

-Each pipeline is a Python entrypoint launched via `torchrun`:
### Architecture Overview

-| Pipeline | Entrypoint | Use Case |
-|----------|-----------|----------|
-| **Wan T2V finetune** | `fastvideo/training/wan_training_pipeline.py` | Standard text-to-video full finetune / LoRA |
-| **Wan I2V finetune** | `fastvideo/training/wan_i2v_training_pipeline.py` | Image-to-video (condition on first frame) |
-| **MatrixGame finetune** | `fastvideo/training/matrixgame_training_pipeline.py` | Action-conditioned world model training |
-| **LTX-2 finetune** | `fastvideo/training/ltx2_training_pipeline.py` | LTX-2 architecture finetuning |
-| **Wan DMD distillation** | `fastvideo/training/wan_distillation_pipeline.py` | Few-step distillation via DMD |
-| **Self-Forcing distill** | `fastvideo/training/wan_self_forcing_distillation_pipeline.py` | Causal streaming distillation |
-| **Base training** | `fastvideo/training/training_pipeline.py` | Base class — not called directly |
```
fastvideo/train/
├── __init__.py → exports Trainer
├── trainer.py → main training loop coordinator
├── entrypoint/
│ ├── train.py → YAML-only training entrypoint
│ └── dcp_to_diffusers.py → checkpoint conversion utility
├── methods/ → training algorithms (TrainingMethod ABC)
│ ├── base.py → TrainingMethod base class
│ ├── fine_tuning/
│ │ ├── finetune.py → FineTuneMethod (supervised finetuning)
│ │ └── dfsft.py → DiffusionForcingSFTMethod (causal)
│ ├── distribution_matching/
│ │ ├── dmd2.py → DMD2Method (distribution matching distill)
│ │ └── self_forcing.py → SelfForcingMethod (causal streaming)
│ ├── knowledge_distillation/ → (stub, not yet implemented)
│ └── consistency_model/ → (stub, not yet implemented)
├── models/ → per-role model instances
│ ├── base.py → ModelBase & CausalModelBase (ABC)
│ └── wan/
│ ├── wan.py → WanModel (non-causal)
│ └── wan_causal.py → WanCausalModel (causal streaming)
├── callbacks/ → training hooks & monitoring
│ ├── callback.py → Callback base class + CallbackDict
│ ├── grad_clip.py → GradNormClipCallback
│ ├── ema.py → EMACallback (shadow weights)
│ └── validation.py → ValidationCallback (sampling + eval)
└── utils/ → configuration, building, checkpointing
├── builder.py → build_from_config() (config → runtime)
├── checkpoint.py → CheckpointManager (DCP-based)
├── config.py → load_run_config() (YAML → RunConfig)
├── training_config.py → TypedConfig dataclasses
├── optimizer.py → build_optimizer_and_scheduler()
├── instantiate.py → resolve_target() + instantiate()
├── tracking.py → build_tracker() (W&B, etc.)
├── dataloader.py → dataloader utilities
├── module_state.py → apply_trainable()
└── moduleloader.py → load_module_from_path()
```

---
### Key Concepts

-## Example Scripts
**TrainingMethod** (`methods/base.py`): Abstract base class for all training
algorithms. Owns role models (student, teacher, critic), manages checkpoint
state, and defines the training step interface.
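
A minimal sketch of the shape this contract might take (hypothetical names; the real ABC in `methods/base.py` may differ):

```python
from abc import ABC, abstractmethod
from typing import Any

class TrainingMethod(ABC):
    """Illustrative sketch of the TrainingMethod contract described above."""

    def __init__(self, models: dict[str, Any]):
        # Role name -> model wrapper, e.g. {"student": ..., "teacher": ...}
        self.models = models

    @abstractmethod
    def single_train_step(self, batch: dict[str, Any]) -> dict[str, float]:
        """Compute losses for one batch and return scalar metrics."""

    def state_dict(self) -> dict[str, Any]:
        # Checkpoint state covers every role the method owns.
        return {role: m.state_dict() for role, m in self.models.items()}
```

A concrete method like `FineTuneMethod` would then subclass this and implement `single_train_step` for its single student role, while `DMD2Method` would coordinate student, teacher, and critic roles.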

-Ready-to-run training launches. Use these as starting templates:
**ModelBase** (`models/base.py`): Per-role model wrapper. Each role (student,
teacher, critic) gets its own `ModelBase` instance owning a `transformer` and
`noise_scheduler`. `CausalModelBase` extends this for streaming models.
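
As a rough illustration of the per-role split (assumed attribute names, not the actual `models/base.py` code):

```python
class ModelBase:
    """Sketch of a per-role model wrapper: one transformer + one scheduler."""

    def __init__(self, transformer, noise_scheduler, trainable: bool = True):
        self.transformer = transformer          # the denoising network
        self.noise_scheduler = noise_scheduler  # owns the timestep schedule
        self.trainable = trainable              # frozen for teacher/critic roles

    def parameters(self):
        # Frozen roles contribute no parameters to the optimizer.
        return self.transformer.parameters() if self.trainable else []

class CausalModelBase(ModelBase):
    """Adds chunked, streaming-style generation on top of ModelBase."""

    def rollout(self, context, num_chunks: int):
        # Hypothetical: denoise chunk-by-chunk, feeding each finished
        # chunk back in as context for the next one.
        for _ in range(num_chunks):
            context = self.transformer.denoise_chunk(context)
        return context
```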

-### Finetuning
-| Model | Script | Notes |
-|-------|--------|-------|
-| Wan T2V 1.3B | `examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh` | Smallest, fastest for testing |
-| Wan T2V 1.3B LoRA | `examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_lora.sh` | Lightweight adapter |
-| Wan I2V 14B | `examples/training/finetune/wan_i2v_14B_480p/crush_smol/finetune_i2v.sh` | Large I2V model |
-| LTX-2 | `examples/training/finetune/ltx2/finetune_t2v.sh` | Alternative architecture |
-| MatrixGame 2.0 | `examples/training/finetune/MatrixGame2.0/finetune_i2v.sh` | Action-conditioned |
**Callback system** (`callbacks/`): Composable hooks for gradient clipping,
EMA, validation, etc. Configured via YAML, dispatched by `CallbackDict`.
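
A toy version of the dispatch pattern, with an assumed `dispatch` helper standing in for whatever `CallbackDict` actually does internally:

```python
class Callback:
    """Base class: hooks default to no-ops, subclasses override what they need."""
    def on_train_start(self, state): ...
    def on_before_optimizer_step(self, state): ...
    def on_training_step_end(self, state): ...

class CallbackDict:
    """Fans one lifecycle event out to every configured callback."""
    def __init__(self, callbacks: dict):
        self.callbacks = callbacks  # name -> Callback, as declared in YAML

    def dispatch(self, hook: str, state: dict):
        for cb in self.callbacks.values():
            getattr(cb, hook)(state)

class GradNormClipCallback(Callback):
    def __init__(self, max_grad_norm: float = 1.0):
        self.max_grad_norm = max_grad_norm

    def on_before_optimizer_step(self, state):
        # The real callback would call torch.nn.utils.clip_grad_norm_ here;
        # this stand-in just caps a scalar to show the hook's timing.
        state["grad_norm"] = min(state.get("grad_norm", 0.0), self.max_grad_norm)
```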

-### Distillation
-| Method | Script | Notes |
-|--------|--------|-------|
-| DMD Wan | `scripts/distill/v1_distill_dmd_wan.sh` | Full distillation launch |
-| DMD Wan + VSA | `scripts/distill/v1_distill_dmd_wan_VSA.sh` | With variable-step acceleration |
-| Consistency (causal ODE init) | `examples/training/consistency_finetune/causal_ode_init/finetune_ode_init.sh` | Consistency tuning |
**Config system** (`utils/config.py`, `utils/training_config.py`): YAML files
are parsed into typed `RunConfig` dataclass trees. Models and methods use
`_target_` fields for instantiation (similar to Hydra).
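
The `_target_` mechanism can be sketched in a few lines; this mirrors Hydra-style instantiation and is only an approximation of `utils/instantiate.py`:

```python
import importlib

def resolve_target(path: str):
    """Import `pkg.mod.Name` and return the Name attribute."""
    module_path, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

def instantiate(config: dict):
    """Build an object from a config node containing `_target_`.

    Every other key in the node is passed to the constructor as a kwarg.
    """
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    return resolve_target(config["_target_"])(**kwargs)
```

With this, a YAML node like `_target_: fastvideo.train.models.wan.WanModel` resolves to the class object and is called with the remaining keys as constructor arguments.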

### Training Flow

```
run_training_from_config(config_path)
→ load_run_config() # YAML → RunConfig
→ init_distributed() # TP/SP setup
→ build_from_config() # instantiate models, method, dataloader
→ Trainer.run() # main loop:
├─ callbacks.on_train_start()
├─ checkpoint_manager.maybe_resume()
├─ for step in range(max_steps):
│ ├─ method.single_train_step(batch)
│ ├─ method.backward()
│ ├─ callbacks.on_before_optimizer_step()
│ ├─ method.optimizers_schedulers_step()
│ ├─ tracker.log(metrics, step)
│ ├─ callbacks.on_training_step_end()
│ └─ checkpoint_manager.maybe_save(step)
├─ callbacks.on_train_end()
└─ checkpoint_manager.save_final()
```

### Training Methods

| Method | Class | Use Case |
|--------|-------|----------|
| **FineTune** | `FineTuneMethod` | Single-role supervised finetuning |
| **DFSFT** | `DiffusionForcingSFTMethod` | Diffusion-forcing SFT with inhomogeneous timesteps |
| **DMD2** | `DMD2Method` | Multi-role distribution matching distillation (student + teacher + critic) |
| **Self-Forcing** | `SelfForcingMethod` | Extends DMD2 for causal student rollouts |

### Launching Training (New Framework)

Training is launched via `torchrun` with a single YAML config:

```bash
torchrun --nproc_per_node <N_GPUS> \
  -m fastvideo.train.entrypoint.train \
  --config examples/train/<config>.yaml
```

-### Data Preprocessing
-Each model directory includes a `preprocess_*` script. Always preprocess first:
-```bash
-# Example for Wan T2V
-bash examples/training/finetune/wan_t2v_1.3B/crush_smol/preprocess_wan_data_t2v.sh
-```

### Example YAML Configs

| Config | Method | Description |
|--------|--------|-------------|
| `examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml` | FineTune | Wan 1.3B finetuning with VSA sparsity |
| `examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml` | DMD2 | Wan 1.3B distillation (student + teacher + critic) |
| `examples/train/dfsft_wan_causal_t2v_1.3B.yaml` | DFSFT | Causal Wan 1.3B diffusion-forcing SFT |
| `examples/train/self_forcing_wan_causal_t2v_1.3B.yaml` | Self-Forcing | Causal streaming distillation |

### Checkpointing (New Framework)

**CheckpointManager** (`utils/checkpoint.py`) saves via `torch.distributed.checkpoint`:

```
output_dir/
└─ checkpoint-{step}/
├─ dcp/ # DCP state dict
├─ config.json # resolved training config
└─ .fastvideo_metadata.json
```

Checkpoint state includes: role model weights, per-role optimizers/schedulers,
CUDA RNG state, and callback state (e.g., EMA shadow weights).
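
The directory layout above can be mocked up as follows (illustrative helper names; the real `CheckpointManager` writes the `dcp/` payload via `torch.distributed.checkpoint`):

```python
import json
from pathlib import Path

def write_checkpoint_skeleton(output_dir: str, step: int, config: dict) -> Path:
    """Create the checkpoint-{step} directory skeleton shown above."""
    ckpt = Path(output_dir) / f"checkpoint-{step}"
    (ckpt / "dcp").mkdir(parents=True, exist_ok=True)   # DCP shards go here
    (ckpt / "config.json").write_text(json.dumps(config, indent=2))
    (ckpt / ".fastvideo_metadata.json").write_text(json.dumps({"step": step}))
    return ckpt

def latest_checkpoint(output_dir: str):
    """Pick the highest-step checkpoint, as a resume-from-latest helper might."""
    ckpts = sorted(Path(output_dir).glob("checkpoint-*"),
                   key=lambda p: int(p.name.split("-")[1]))
    return ckpts[-1] if ckpts else None
```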

### Config Structure

A YAML config defines the full training pipeline:

```yaml
models:
student:
_target_: fastvideo.train.models.wan.WanModel
model_path: ...
trainable: true
teacher: # optional, for distillation
_target_: fastvideo.train.models.wan.WanModel
model_path: ...
trainable: false

method:
_target_: fastvideo.train.methods.fine_tuning.FineTuneMethod
# method-specific params...

training:
distributed: { num_gpus: 8, tp_size: 1, sp_size: 8 }
data: { data_path: ..., batch_size: 1 }
optimizer: { lr: 1e-5, lr_scheduler: constant_with_warmup }
loop: { max_train_steps: 1000 }
checkpoint: { output_dir: ./outputs }
tracker: { trackers: [wandb], project_name: ... }

callbacks:
grad_clip:
_target_: fastvideo.train.callbacks.GradNormClipCallback
max_grad_norm: 1.0
validation:
_target_: fastvideo.train.callbacks.ValidationCallback
validation_steps: 100
```
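
The typed-config idea can be illustrated with plain dataclasses (field names mirror the YAML above, but the real `training_config.py` dataclasses are richer, and a real loader would parse the YAML file first):

```python
from dataclasses import dataclass, field

@dataclass
class OptimizerConfig:
    lr: float = 1e-5
    lr_scheduler: str = "constant_with_warmup"

@dataclass
class LoopConfig:
    max_train_steps: int = 1000

@dataclass
class TrainingConfig:
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    loop: LoopConfig = field(default_factory=LoopConfig)

def training_config_from_dict(raw: dict) -> TrainingConfig:
    """Turn a parsed YAML `training:` section into typed dataclasses."""
    return TrainingConfig(
        optimizer=OptimizerConfig(**raw.get("optimizer", {})),
        loop=LoopConfig(**raw.get("loop", {})),
    )
```

Typed nodes like these give early failures on misspelled keys and defaults for omitted ones, which is the main benefit over passing raw dicts around.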

---

## Legacy Training Pipelines (`fastvideo/training/`)

> **Note:** Use the new `fastvideo/train/` framework for new work. This section
> is retained for reference on existing pipelines not yet migrated.

| Pipeline | Entrypoint | Use Case |
|----------|-----------|----------|
| Wan T2V finetune | `fastvideo/training/wan_training_pipeline.py` | Standard text-to-video finetune / LoRA |
| Wan I2V finetune | `fastvideo/training/wan_i2v_training_pipeline.py` | Image-to-video (first frame conditioned) |
| MatrixGame finetune | `fastvideo/training/matrixgame_training_pipeline.py` | Action-conditioned world model |
| LTX-2 finetune | `fastvideo/training/ltx2_training_pipeline.py` | LTX-2 architecture finetuning |
| Wan DMD distillation | `fastvideo/training/wan_distillation_pipeline.py` | Few-step distillation via DMD |
| Self-Forcing distill | `fastvideo/training/wan_self_forcing_distillation_pipeline.py` | Causal streaming distillation |

---

## Key Infrastructure

### W&B Integration
- **Tracker**: `fastvideo/training/trackers.py` — `WandbTracker` class
- **New framework tracker**: `fastvideo/train/utils/tracking.py` — `build_tracker()`
- **Env vars**: `WANDB_API_KEY`, `WANDB_BASE_URL`, `WANDB_MODE`
- **Summaries**: Saved to `<output_dir>/tracker/wandb/latest-run/files/wandb-summary.json`
- **Tests**: `fastvideo/tests/training/Vanilla/test_training_loss.py` (compares against reference summaries)

### Checkpointing
- **Save**: `fastvideo/training/training_utils.py:save_checkpoint()`
- **Load**: `fastvideo/training/training_utils.py:load_checkpoint()`
- **Distill save**: `training_utils.py:save_distillation_checkpoint()` (multi-model)
- **Format**: FSDP distributed checkpoint → converted to HF format

### Parallelism
-- **SP** (Sequence Parallel): splits video frames across GPUs — `--sp_size N`
-- **TP** (Tensor Parallel): splits model layers across GPUs — `--tp_size N`
- **SP** (Sequence Parallel): splits video frames across GPUs — `sp_size: N`
- **TP** (Tensor Parallel): splits model layers across GPUs — `tp_size: N`
- Typical configs: SP=2–8, TP=1–2
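
One way to sanity-check a sizing choice: the product of the TP and SP degrees must divide the GPU count evenly, with the quotient acting as the data-parallel degree (helper name is illustrative, not a FastVideo API):

```python
def data_parallel_degree(num_gpus: int, tp_size: int, sp_size: int) -> int:
    """Return the implied data-parallel degree for a TP/SP layout."""
    group = tp_size * sp_size
    if num_gpus % group != 0:
        raise ValueError(f"{num_gpus} GPUs not divisible by tp*sp={group}")
    return num_gpus // group
```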

---
@@ -148,23 +293,10 @@ their own error accumulation during long auto-regressive generation.

### DMD Distillation (Distribution Matching Distillation)
Reduces inference steps from ~50 to 3–4 by training a student model to match
-the output distribution of the teacher model. Uses ODE pairs collected from
-the teacher model as training targets.
the output distribution of the teacher model. Uses a critic network to estimate
distribution divergence.

---

-## Quick-Start: Minimal Training Test
-
-To verify the training infrastructure works, run the smallest possible experiment:
-
-```bash
-# 1. Download crush_smol dataset
-bash examples/training/finetune/wan_t2v_1.3B/crush_smol/download_dataset.sh
-
-# 2. Preprocess
-bash examples/training/finetune/wan_t2v_1.3B/crush_smol/preprocess_wan_data_t2v.sh
-
-# 3. Short training run (5 steps)
-# Edit finetune_t2v.sh: set --max_train_steps 5
-bash examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh
-```
### Diffusion-Forcing SFT (DFSFT)
Supervised finetuning with **inhomogeneous timesteps** across chunks — each
chunk in a causal sequence can have a different noise level, training the model
to handle mixed-fidelity contexts.
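
The core sampling idea can be sketched as drawing an independent timestep per chunk (hypothetical helper, not the actual FastVideo sampling code):

```python
import random

def sample_chunk_timesteps(num_chunks: int, num_train_timesteps: int = 1000,
                           rng=None) -> list:
    """Draw an independent noise level for every chunk in a causal sequence,
    rather than one shared timestep for the whole video."""
    rng = rng or random.Random()
    return [rng.randrange(num_train_timesteps) for _ in range(num_chunks)]
```

Training against such mixed noise levels is what lets a causal model condition on clean earlier chunks while denoising noisier later ones.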
2 changes: 2 additions & 0 deletions .agents/skills/SKILL_TEMPLATE.md
@@ -28,6 +28,7 @@ description: <one-line description — Codex uses this for implicit invocation m
- <What this skill produces>

## Example Usage

```
<Example invocation or prompt snippet>
```
@@ -50,6 +51,7 @@ Each skill lives in its own directory under `.agents/skills/`:
```

After creating a new skill, add an entry to `.agents/skills/index.jsonl`:

```json
{"name": "<skill-name>", "description": "<description>", "path": "<skill-name>/SKILL.md", "status": "draft", "trust": "low"}
```
1 change: 1 addition & 0 deletions .agents/skills/evaluate-video-quality/SKILL.md
@@ -38,6 +38,7 @@ pytest fastvideo/tests/ssim/ -vs --video-path <generated> --reference-path <refe
```

Or use the SSIM utility directly:

```python
from fastvideo.tests.ssim.ssim_utils import compute_ssim
score = compute_ssim(generated_video, reference_video)