@wr0124 commented on Nov 3, 2025

This PR introduces fine-tuning support for the consistency model (CM) using pretrained DDPM weights. The goal is to improve sample quality and training stability by initializing the CM with pretrained diffusion features.

Key Changes

  • Integrated pretrained DDPM checkpoint loading (a hedged sketch follows this list).
  • Added fine-tuning logic for the consistency model.
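
For orientation, the checkpoint-loading step could look roughly like the sketch below. This is a minimal illustration, not the code added by this PR: the function name load_ddpm_into_cm, the checkpoint wrapper keys, and the shape-matching filter are all assumptions.

import torch

def load_ddpm_into_cm(cm_net: torch.nn.Module, ddpm_ckpt_path: str) -> None:
    """Initialize a consistency model from pretrained DDPM weights (illustrative).

    Only tensors whose names and shapes match are copied, so any CM-specific
    heads keep their fresh initialization while the shared backbone starts
    from diffusion features.
    """
    ddpm_state = torch.load(ddpm_ckpt_path, map_location="cpu")
    # Checkpoints are often wrapped under a key such as "ema" or "model"
    # (assumed layout, not necessarily joliGEN's).
    for key in ("ema", "model", "state_dict"):
        if isinstance(ddpm_state, dict) and key in ddpm_state:
            ddpm_state = ddpm_state[key]
            break
    cm_state = cm_net.state_dict()
    compatible = {
        name: w
        for name, w in ddpm_state.items()
        if name in cm_state and cm_state[name].shape == w.shape
    }
    cm_state.update(compatible)
    cm_net.load_state_dict(cm_state)
    print(f"Initialized {len(compatible)}/{len(cm_state)} tensors from the DDPM checkpoint")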

Run the fine-tuning with the following command:

python3 -W ignore::FutureWarning -W ignore::UserWarning train.py \
--dataroot path/to/data \
--checkpoints_dir path/to/ckpt/ \
--name ddpm_cm \
--gpu_ids 1 \
--data_relative_paths \
--model_type cm \
--data_dataset_mode self_supervised_vid_mask_online \
--train_batch_size 1 \
--dataaug_no_rotate \
--train_iter_size 16 \
--data_num_threads 16 \
--train_G_ema \
--train_G_lr 0.00002 \
--data_temporal_number_frames 6 \
--data_temporal_frame_step 1 \
--train_optim adamw \
--G_netG unet_vid \
--data_online_creation_rand_mask_A \
--output_print_freq 16 \
--output_display_freq 16 \
--data_crop_size 128 \
--data_load_size 128 \
--train_compute_metrics_test \
--train_metrics_every 16 \
--train_metrics_list PSNR LPIPS SSIM \
--with_amp \
--with_tf32 \
--data_online_creation_crop_size_A 300 \
--data_online_creation_crop_size_B 300 \
--alg_cm_metric_mask \
--train_continue \
--alg_diffusion_ddpm_cm_ft
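
The last two flags drive the new behavior: --train_continue presumably resumes from the weights already under --checkpoints_dir (here, the pretrained DDPM run), and --alg_diffusion_ddpm_cm_ft enables the DDPM-to-CM fine-tuning path added by this PR. The description does not spell out the training objective, but consistency fine-tuning in the style of Song et al. (2023) typically pulls the network's outputs at adjacent noise levels together against an EMA target (--train_G_ema maintains such an EMA copy). A minimal sketch, with every name (consistency_ft_step, cm, cm_ema, sigmas) hypothetical:

import torch
import torch.nn.functional as F

def consistency_ft_step(cm, cm_ema, x0, sigmas, opt):
    """One consistency fine-tuning step (illustrative, not the PR's code).

    cm / cm_ema are the online and EMA copies of the consistency network,
    both assumed to map (noisy_sample, sigma) -> denoised sample;
    sigmas is a 1-D tensor of increasing noise levels.
    """
    b = x0.shape[0]
    n = torch.randint(0, len(sigmas) - 1, (b,), device=x0.device)
    s_cur, s_next = sigmas[n], sigmas[n + 1]
    noise = torch.randn_like(x0)
    # The same noise draw places both samples on the same trajectory.
    x_cur = x0 + s_cur.view(b, 1, 1, 1) * noise
    x_next = x0 + s_next.view(b, 1, 1, 1) * noise
    with torch.no_grad():
        target = cm_ema(x_cur, s_cur)  # EMA network provides the target
    loss = F.mse_loss(cm(x_next, s_next), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()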
