High-level API examples for training, fine-tuning, and generating with DFM models.
| Task | Directory | Available Examples |
|---|---|---|
| Fine-tuning | finetune | • Wan 2.1 T2V: Fine-tuning with Flow Matching • Multi-node: Distributed training config |
| Generation | generate | • Generate: Run inference with Wan 2.1 • Validate: Run validation loop |
| Pre-training | pretrain | • Wan 2.1 T2V: Pre-training from scratch |
Train diffusion models with distributed training support using NeMo Automodel and flow matching.
Currently Supported: Wan 2.1 Text-to-Video (1.3B and 14B models)
```bash
# Build image
docker build -f docker/Dockerfile.ci -t dfm-training .

# Run container
docker run --gpus all -it \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  --ipc=host \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  dfm-training bash
```

```bash
# Inside container: initialize submodules
export UV_PROJECT_ENVIRONMENT=
git submodule update --init --recursive 3rdparty/
```

Create video dataset:
```
<your_video_folder>/
├── video1.mp4
├── video2.mp4
└── meta.json
```
meta.json format:
```json
[
  {
    "file_name": "video1.mp4",
    "width": 1280,
    "height": 720,
    "start_frame": 0,
    "end_frame": 121,
    "vila_caption": "A detailed description of the video content..."
  }
]
```

Preprocess videos to `.meta` files:
There are two preprocessing modes:
Mode 1: Full video (recommended for training)
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode video \
  --video_folder <your_video_folder> \
  --output_folder ./processed_meta \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 480 \
  --width 720 \
  --center-crop
```

Mode 2: Extract frames (for frame-based training)
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode frames \
  --num-frames 40 \
  --video_folder <your_video_folder> \
  --output_folder ./processed_frames \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 240 \
  --width 416 \
  --center-crop
```

Key arguments:
- `--mode`: `video` (full video) or `frames` (extract evenly spaced frames)
- `--num-frames`: number of frames to extract (only for `frames` mode)
- `--height` / `--width`: target resolution
- `--center-crop`: crop to the exact size after an aspect-preserving resize
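The interaction of the aspect-preserving resize and `--center-crop` can be sketched with a small geometry helper. This is an illustration of the behavior described above, not the preprocessing script's actual implementation:

```python
def resize_then_crop_dims(w, h, target_w, target_h):
    """Compute the intermediate resize size and the center-crop box.

    The frame is scaled so it fully covers the target resolution while
    preserving aspect ratio, then cropped to the exact target size.
    """
    scale = max(target_w / w, target_h / h)
    new_w, new_h = round(w * scale), round(h * scale)
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return (new_w, new_h), (left, top, left + target_w, top + target_h)
```

For example, a 1280×720 source targeted at 720×480 is first resized to 853×480, then the width is center-cropped from 853 down to 720.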
Preprocessing modes:
- `video` mode: processes the entire video sequence and creates one `.meta` file per video
- `frames` mode: extracts N evenly spaced frames and creates one `.meta` file per frame (each treated as a 1-frame video)
Output: Creates .meta files containing:
- Encoded video latents (normalized)
- Text embeddings (from UMT5)
- First frame as JPEG (video mode only)
- Metadata
Single-node (8 GPUs):
```bash
export UV_PROJECT_ENVIRONMENT=
uv run --group automodel --with . \
  torchrun --nproc-per-node=8 \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow.yaml
```

Multi-node with SLURM:
```bash
#!/bin/bash
#SBATCH -N 2
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-node=8
#SBATCH --exclusive

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
export NUM_GPUS=8

# Per-rank UV cache to avoid conflicts
unset UV_PROJECT_ENVIRONMENT
mkdir -p /opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}
export UV_CACHE_DIR=/opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}

uv run --group automodel --with . \
  torchrun \
  --nnodes=$SLURM_NNODES \
  --nproc-per-node=$NUM_GPUS \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml
```

Run validation on a trained checkpoint:

```bash
uv run --group automodel --with . \
  python examples/automodel/generate/wan_validate.py \
  --meta_folder <your_meta_folder> \
  --guidance_scale 5 \
  --checkpoint ./checkpoints/step_1000 \
  --num_samples 10
```

Note: You can use `--checkpoint ./checkpoints/LATEST` to automatically use the most recent checkpoint.
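If you script around checkpoints yourself, be aware that plain lexicographic sorting ranks `step_1000` before `step_200`. A hypothetical helper (not part of this codebase) that resolves the newest `step_*` directory numerically:

```python
from pathlib import Path

def latest_checkpoint(checkpoint_dir):
    # Compare step_* entries by their numeric suffix, not as strings,
    # so step_1000 correctly ranks above step_200
    steps = Path(checkpoint_dir).glob("step_*")
    return max(steps, key=lambda p: int(p.name.split("_")[-1]), default=None)
```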
```yaml
model:
  pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  num_epochs: 10
  ckpt_every_steps: 100

data:
  dataloader:
    meta_folder: "<your_processed_meta_folder>"
    num_workers: 2

optim:
  learning_rate: 5e-6

flow_matching:
  timestep_sampling: "uniform"
  flow_shift: 3.0

fsdp:
  dp_size: 8  # Single node: 8 GPUs

checkpoint:
  enabled: true
  checkpoint_dir: "./checkpoints"
```

For multi-node runs, scale the data-parallel settings instead:

```yaml
fsdp:
  dp_size: 16           # 2 nodes × 8 GPUs
  dp_replicate_size: 2  # Replicate across 2 nodes
```

| Setting | Fine-tuning | Pretraining |
|---|---|---|
| `learning_rate` | 5e-6 | 5e-5 |
| `weight_decay` | 0.01 | 0.1 |
| `flow_shift` | 3.0 | 2.5 |
| `logit_std` | 1.0 | 1.5 |
| Dataset size | 100s-1000s | 10K+ |
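The `flow_shift` values above rescale sampled timesteps toward noisier levels. A shifting rule commonly used in flow-matching pipelines, assumed here for illustration, is t' = s·t / (1 + (s − 1)·t):

```python
def shift_timestep(t, shift):
    """Map a uniform t in [0, 1] to a shifted timestep.

    shift > 1 pushes samples toward t = 1 (higher noise);
    shift = 1 leaves the distribution unchanged.
    """
    return shift * t / (1 + (shift - 1) * t)
```

With `flow_shift: 3.0`, a uniformly sampled t = 0.5 maps to 0.75, biasing training toward the high-noise end of the trajectory.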
| Component | Minimum | Recommended |
|---|---|---|
| GPU | A100 40GB | A100 80GB / H100 |
| GPUs | 4 | 8+ |
| RAM | 128 GB | 256 GB+ |
| Storage | 500 GB SSD | 2 TB NVMe |
- ✅ Flow Matching: Pure flow matching training
- ✅ Distributed: FSDP2 + Tensor Parallelism
- ✅ Mixed Precision: BF16 by default
- ✅ WandB: Automatic logging
- ✅ Checkpointing: Consolidated and sharded formats
- ✅ Multi-node: SLURM and torchrun support
| Model | Parameters | Parallelization | Status |
|---|---|---|---|
| Wan 2.1 T2V 1.3B | 1.3B | FSDP2 via Automodel + DDP | ✅ |
| Wan 2.1 T2V 14B | 14B | FSDP2 via Automodel + DDP | ✅ |
| FLUX | TBD | TBD | 🔄 In Progress |
Custom parallelization:
```yaml
fsdp:
  tp_size: 2  # Tensor parallel
  dp_size: 4  # Data parallel
```

Checkpoint cleanup:
```python
from pathlib import Path
import shutil

def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
    # Sort numerically by step so step_1000 ranks after step_200
    # (lexicographic sorting would delete the wrong checkpoints)
    checkpoints = sorted(
        Path(checkpoint_dir).glob("step_*"),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    for old_ckpt in checkpoints[:-keep_last_n]:
        shutil.rmtree(old_ckpt)
```