This is the training code for the Jax/Flax implementation of Few-shot Image Generation via Cross-domain Correspondence.
- Getting Started
- Preparing Datasets for Training
- Training
- Checkpoints
- Generating Images
- References
- License
You will need Python 3.7 or later.
- Clone the repository:
> git clone https://github.com/matthias-wright/flaxmodels.git
- Go into the directory:
> cd flaxmodels/training/few_shot_gan_adaption
- Install Jax with CUDA.
- Install requirements:
> pip install -r requirements.txt
Before training, the images should be stored in a TFRecord dataset. The TFRecord format stores your data as a sequence of bytes, which allows for fast data loading.
Alternatively, you can use tfds.folder_dataset.ImageFolder on the image directory directly, but you will then have to replace the tf.data.TFRecordDataset in data_pipeline.py with tfds.folder_dataset.ImageFolder (see this thread for more info).
- Download dataset from here.
- Put all images into a directory:
/path/to/image_dir/
    0.jpg
    1.jpg
    2.jpg
    4.jpg
    ...
- Create TFRecord dataset:
> python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
--image_dir is the path to the image directory.
--data_dir is the path where the TFRecord dataset is stored.
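To make the "sequence of bytes" idea concrete, here is a minimal sketch of writing encoded images to a TFRecord file and reading them back with TensorFlow. It is not the code in dataset_utils/images_to_tfrecords.py: the feature key 'image', the PNG encoding, the helper names, and the toy images are assumptions chosen for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def write_tfrecord(encoded_images, path):
    """Write each encoded image (bytes) as one tf.train.Example record."""
    with tf.io.TFRecordWriter(path) as writer:
        for data in encoded_images:
            # 'image' is a hypothetical feature key used only in this sketch.
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def read_tfrecord(path):
    """Return a tf.data.Dataset that yields decoded uint8 images."""
    spec = {'image': tf.io.FixedLenFeature([], tf.string)}

    def parse(record):
        parsed = tf.io.parse_single_example(record, spec)
        return tf.io.decode_png(parsed['image'], channels=3)

    return tf.data.TFRecordDataset(path).map(parse)


# Three tiny solid-color images stand in for a real image directory.
images = [np.full((4, 4, 3), v, dtype=np.uint8) for v in (0, 128, 255)]
encoded = [tf.io.encode_png(img).numpy() for img in images]

path = os.path.join(tempfile.mkdtemp(), 'toy.tfrecords')
write_tfrecord(encoded, path)
decoded = [img.numpy() for img in read_tfrecord(path)]
```

Because each record is just a serialized byte string, the reader can stream records sequentially from disk, which is where the fast loading comes from.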
Download checkpoint of source model:
> wget https://www.dropbox.com/s/hyh1k8ixtzy24ye/ffhq_256x256.pickle\?dl\=1 -O ffhq_256x256.pickle
Start training:
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --source_ckpt_path ffhq_256x256.pickle
Here a, b, c, d are the GPU indices. Multi-GPU training (data parallelism) works by default and automatically uses all the devices that you make visible.
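As an illustration of what data parallelism over the visible devices can look like in JAX, here is a minimal sketch based on jax.pmap. It is not taken from training.py; the helper names shard_batch and per_device_mean and the toy batch are hypothetical. On a machine with a single visible device it runs with one shard.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch, num_devices):
    """Reshape (B, ...) into (num_devices, B // num_devices, ...)."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError('Batch size must be divisible by the number of devices.')
    return batch.reshape((num_devices, -1) + batch.shape[1:])


# Number of devices JAX can see (restricted by CUDA_VISIBLE_DEVICES on GPU).
num_devices = jax.local_device_count()


def per_device_mean(x):
    # Each device reduces its own shard, then the results are averaged
    # across devices, similar to how gradients are averaged.
    local = jnp.mean(x)
    return jax.lax.pmean(local, axis_name='batch')


p_mean = jax.pmap(per_device_mean, axis_name='batch')

# Toy global batch with 8 examples per device.
batch = np.arange(8 * num_devices * 3, dtype=np.float32).reshape(8 * num_devices, 3)
sharded = shard_batch(batch, num_devices)
means = p_mean(sharded)  # one (identical) value per device
```

Every device ends up with the same averaged value, which is the property a data-parallel training step relies on to keep parameter replicas in sync.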
I use Weights & Biases for logging, but you can replace it with the logging method of your choice. All logging happens in the training loop implemented in training.py. To enable logging with Weights & Biases, use --wandb.
By default, the FID score is evaluated on 10,000 images every 1000 training steps. The checkpoint with the best (lowest) FID score is saved. You can change the evaluation frequency using the --eval_fid_every argument and the number of images used to evaluate the FID score using --num_fid_images.
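For intuition about the score itself, the sketch below computes the Fréchet distance between two Gaussians fitted to feature vectors, which is the quantity behind FID. In practice the features come from an Inception network; here random toy features are used instead, and the function names are hypothetical rather than the ones used in this repository.

```python
import numpy as np


def feature_stats(features):
    """Mean vector and covariance matrix of an (N, D) feature array."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))."""
    diff = mu1 - mu2
    # The trace of the matrix square root equals the sum of the square
    # roots of the eigenvalues of sigma1 @ sigma2.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))  # toy stand-in for real-image features
fake = real + 1.0                  # same spread, every feature shifted by 1

fid_same = frechet_distance(*feature_stats(real), *feature_stats(real))
fid_shifted = frechet_distance(*feature_stats(real), *feature_stats(fake))
```

Identical feature sets give a distance of about zero, and the distance grows as the two distributions move apart, which is why a lower FID indicates a better generator.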
You can disable the FID score evaluation using --disable_fid. In that case, a checkpoint will be saved every 2000 steps (can be changed using --save_every).
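The checkpointing behavior described here can be sketched as a small decision function. This is an illustration under stated assumptions, not the logic in training.py: checkpoint_decision and the toy FID values are hypothetical, the keyword names only mirror the command-line flags, and a lower FID is treated as better (the usual convention).

```python
import math


def checkpoint_decision(step, best_fid, fid_fn=None, disable_fid=False,
                        eval_fid_every=1000, save_every=2000):
    """Return (save, best_fid) for a given training step."""
    if disable_fid:
        # Without FID evaluation, save on a fixed schedule.
        return step % save_every == 0, best_fid
    if step % eval_fid_every != 0:
        return False, best_fid
    fid = fid_fn(step)
    if fid < best_fid:
        # New best score: save and remember it.
        return True, fid
    return False, best_fid


# Toy FID values per evaluation step (made up for this example).
toy_fid = {1000: 80.0, 2000: 60.0, 3000: 65.0, 4000: 55.0}

best = math.inf
saved_steps = []
for step in range(1, 4001):
    save, best = checkpoint_decision(step, best, fid_fn=toy_fid.get)
    if save:
        saved_steps.append(step)

# With FID disabled, checkpoints fall on the fixed schedule instead.
periodic_steps = [s for s in range(1, 4001)
                  if checkpoint_decision(s, math.inf, disable_fid=True)[0]]
```

In the toy run, step 3000 is skipped because its score is worse than the best seen so far, while the periodic schedule saves regardless of quality.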
- Sketches (357.2 MB)
- Amedeo Modigliani (357.2 MB)
- Babies (357.2 MB)
- Otto Dix (357.2 MB)
- Rafael (357.2 MB)
import jax
import numpy as np
import dill as pickle
from PIL import Image
import flaxmodels as fm

# Load the checkpoint and select the EMA generator parameters
with open('sketches.pickle', 'rb') as f:
    ckpt = pickle.load(f)
params = ckpt['params_ema_G']
generator = fm.few_shot_gan_adaption.Generator()

# Seed
key = jax.random.PRNGKey(0)
# Input noise
z = jax.random.normal(key, shape=(4, 512))
# Generate images
images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const')
# Normalize images to be in range [0, 1]
images = (images - np.min(images)) / (np.max(images) - np.min(images))
# Save images
for i in range(images.shape[0]):
    Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
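When generating several images at once, it can be convenient to tile them into a single grid before saving. The following is a small NumPy sketch of that idea; to_uint8 and make_grid are hypothetical helpers that are not part of flaxmodels, and random arrays stand in for the generator output so the snippet runs without a checkpoint.

```python
import numpy as np


def to_uint8(images):
    """Min-max normalize a float array to [0, 255] and cast to uint8."""
    lo, hi = images.min(), images.max()
    scaled = (images - lo) / max(hi - lo, 1e-8)
    return (scaled * 255).round().astype(np.uint8)


def make_grid(images, ncols):
    """Tile an (N, H, W, C) array into one (rows * H, ncols * W, C) image."""
    n, h, w, c = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((nrows * h, ncols * w, c), dtype=images.dtype)
    for idx in range(n):
        row, col = divmod(idx, ncols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = images[idx]
    return grid


# Random stand-in for generator output of shape (4, H, W, 3).
rng = np.random.default_rng(0)
fake_images = rng.normal(size=(4, 8, 8, 3)).astype(np.float32)
grid = make_grid(to_uint8(fake_images), ncols=2)
```

The resulting array can be passed to Image.fromarray(grid).save('grid.jpg') in the same way as the individual images.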