
Adversarial Flow Models

Official PyTorch implementation of Adversarial Flow Models.

Adversarial Flow Models
Shanchuan Lin, Ceyuan Yang, Zhijie Lin, Hao Chen, Haoqi Fan
ByteDance Seed


Abstract

We present adversarial flow models, a class of generative models that unifies adversarial models and flow models. Our method supports native one-step or multi-step generation and is trained using the adversarial objective. Unlike traditional GANs, where the generator learns an arbitrary transport plan between the noise and the data distributions, our generator learns a deterministic noise-to-data mapping, which is the same optimal transport as in flow-matching models. This significantly stabilizes adversarial training. Also, unlike consistency-based methods, our model directly learns one-step or few-step generation without needing to learn the intermediate timesteps of the probability flow for propagation. This saves model capacity, reduces training iterations, and avoids error accumulation. Under the same 1NFE setting on ImageNet-256px, our B/2 model approaches the performance of consistency-based XL/2 models, while our XL/2 model creates a new best FID of 2.38. We additionally show the possibility of end-to-end training of 56-layer and 112-layer models through depth repetition without any intermediate supervision, and achieve FIDs of 2.08 and 1.94 using a single forward pass, surpassing their 2NFE and 4NFE counterparts.

Playground

Try adversarial flow on 1D Gaussian mixture.
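The 1D playground illustrates the deterministic noise-to-data mapping described in the abstract. In one dimension, that optimal-transport map has a simple empirical form: match quantiles of the noise and data distributions. Below is a minimal numpy sketch of that map on a two-component Gaussian mixture; it is illustrative only and is not the repo's playground code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target distribution: a two-component 1D Gaussian mixture
def sample_mixture(n):
    comp = rng.random(n) < 0.5
    return np.where(comp,
                    rng.normal(-2.0, 0.5, n),
                    rng.normal(2.0, 0.5, n))

# Empirical 1D optimal-transport map: in one dimension the OT map is
# monotone and simply matches quantiles of noise and data.
noise = np.sort(rng.normal(0.0, 1.0, 100_000))
data = np.sort(sample_mixture(100_000))

def transport(z):
    # Rank z among the noise samples, then read off the same quantile of the data
    u = np.searchsorted(noise, z) / len(noise)
    return np.quantile(data, np.clip(u, 0.0, 1.0))

# Pushing standard normal noise through the map yields mixture samples
samples = transport(rng.normal(0.0, 1.0, 10_000))
```

Because the 1D OT map is monotone, every noise sample has a unique, deterministic image in the data distribution; this is the kind of transport plan the adversarial flow generator learns, rather than an arbitrary GAN coupling.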

Code

Please download dit.py from the original DiT repo (licensed under Attribution-NonCommercial 4.0 International) and place it at models/dit.py.

Checkpoints

Download checkpoints.

  • models/ Pre-trained ImageNet-256px checkpoints.
  • eval/ Pre-generated 50k samples used for FID evaluation. The npz format follows the evaluation script provided by ADM.
  • misc/ VAE and other checkpoints used in training.

Generate

The generation configurations are provided in /configs/generate. Please download the pretrained checkpoint and edit the YAML to point to it. Run the command below to generate 50k samples for FID evaluation.

python3 main.py configs/generate/generate_1nfe.yaml

Or use multiple GPUs

TORCHRUN main.py configs/generate/generate_1nfe.yaml

Note that TORCHRUN denotes the torchrun command with your GPU configuration, e.g. torchrun --nproc_per_node=8.

Training

The training configurations are provided in /configs/train.

TORCHRUN main.py configs/train/train_1nfe.yaml

Our training scripts are refactored for public release and are provided for reference purposes. Our codebase was originally written to run on our internal platform, which uses the Parquet dataset format and wandb for logging. If you want to use our code for training, you may need to adapt the dataloading and logging logic to your own dataset and logging framework.

  • We used Parquet format for storing and loading the dataset on HDFS. We packed 1,281,167 ImageNet training samples into 256 .parquet files; each .parquet file has 69 row groups. Our dataloading code loops over all samples in an infinite loop without the concept of an epoch. G and D updates are counted as separate iterations, so with a batch size of 256 one epoch corresponds to 1,281,167 / 256 × 2 ≈ 10,009 iterations, and 1M iterations is approximately 100 epochs.

  • The training schedule is provided in Table 11 of the paper. The current approach still requires more manual intervention. This is a limitation we hope to improve in future work.

Founded in 2023, ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society.