-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathrun.sh
More file actions
executable file
·124 lines (119 loc) · 3.39 KB
/
Copy pathrun.sh
File metadata and controls
executable file
·124 lines (119 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ---------------------------------------------------------------------------
# General settings -- edit these to choose what this script runs.
# ---------------------------------------------------------------------------
GPU=0                      # index of the GPU to run on
SEED=42                    # random seed used when sampling
CHANNELS=64                # base channel count of the model (64 in all experiments)
MODE="train"               # 'train' or 'sample'
DATASET="brats"            # 'brats' or 'lidc-idri'
MODEL="ours_unet_128"      # or 'ours_unet_256', 'ours_wnet_128', 'ours_wnet_256'

# ---------------------------------------------------------------------------
# Sampling / inference settings (only used when MODE is 'sample').
# ---------------------------------------------------------------------------
ITERATIONS=0               # checkpoint to sample from, in thousands of training iterations
SAMPLING_STEPS=0           # accelerated-sampling step count; 0 keeps the default 1000
RUN_DIR=""                 # run (tensorboard) directory holding the checkpoints to evaluate
# detailed settings (no need to change for reproducing)
#
# Each supported MODEL fixes the values later passed to Python as
# --channel_mult, --image_size, --additive_skips, --use_freq and --batch_size.
# An unsupported MODEL aborts the script here: continuing would launch Python
# with all of these flags empty.
case "$MODEL" in
  ours_unet_128)
    echo "MODEL: WDM (U-Net) 128 x 128 x 128"
    CHANNEL_MULT=1,2,2,4,4
    IMAGE_SIZE=128
    ADDITIVE_SKIP=True
    USE_FREQ=False
    BATCH_SIZE=10
    ;;
  ours_unet_256)
    echo "MODEL: WDM (U-Net) 256 x 256 x 256"
    CHANNEL_MULT=1,2,2,4,4,4
    IMAGE_SIZE=256
    ADDITIVE_SKIP=True
    USE_FREQ=False
    BATCH_SIZE=1
    ;;
  ours_wnet_128)
    echo "MODEL: WDM (WavU-Net) 128 x 128 x 128"
    CHANNEL_MULT=1,2,2,4,4
    IMAGE_SIZE=128
    ADDITIVE_SKIP=False
    USE_FREQ=True
    BATCH_SIZE=10
    ;;
  ours_wnet_256)
    echo "MODEL: WDM (WavU-Net) 256 x 256 x 256"
    CHANNEL_MULT=1,2,2,4,4,4
    IMAGE_SIZE=256
    ADDITIVE_SKIP=False
    USE_FREQ=True
    BATCH_SIZE=1
    ;;
  *)
    # Diagnostics go to stderr; exit non-zero so callers/CI see the failure.
    echo "MODEL TYPE NOT FOUND -> Check the supported configurations again" >&2
    exit 1
    ;;
esac
# Report the selected mode and apply mode-specific settings:
#   sample -> force BATCH_SIZE=1 (change it here to sample with a larger batch;
#             no need to change it to reproduce the published results)
#   train  -> pick the training data directory for DATASET
# An unsupported MODE or (in train mode) an unsupported DATASET aborts the
# script. Previously an unknown dataset left DATA_DIR empty. An unknown mode
# was silently run as sampling by the final dispatch.
case "$MODE" in
  sample)
    echo "MODE: sample"
    BATCH_SIZE=1
    ;;
  train)
    case "$DATASET" in
      brats)
        echo "MODE: training"
        echo "DATASET: BRATS"
        DATA_DIR=~/wdm-3d/data/BRATS/
        ;;
      lidc-idri)
        echo "MODE: training"
        echo "Dataset: LIDC-IDRI"
        DATA_DIR=~/wdm-3d/data/LIDC-IDRI/
        ;;
      *)
        echo "DATASET NOT FOUND -> Check the supported datasets again" >&2
        exit 1
        ;;
    esac
    ;;
  *)
    echo "MODE NOT FOUND -> Use 'train' or 'sample'" >&2
    exit 1
    ;;
esac
# Command-line arguments for the Python entry points.
# They are kept in arrays rather than whitespace-separated strings. Each
# element then reaches Python as exactly one argv word, even if a value such
# as DATA_DIR or RUN_DIR contains spaces or glob characters.

# Arguments shared by training and sampling.
COMMON=(
  "--dataset=${DATASET}"
  "--num_channels=${CHANNELS}"
  --class_cond=False
  --num_res_blocks=2
  --num_heads=1
  --learn_sigma=False
  --use_scale_shift_norm=False
  --attention_resolutions=
  "--channel_mult=${CHANNEL_MULT}"
  --diffusion_steps=1000
  --noise_schedule=linear
  --rescale_learned_sigmas=False
  --rescale_timesteps=False
  --dims=3
  "--batch_size=${BATCH_SIZE}"
  --num_groups=32
  --in_channels=8
  --out_channels=8
  --bottleneck_attention=False
  --resample_2d=False
  --renormalize=True
  "--additive_skips=${ADDITIVE_SKIP}"
  "--use_freq=${USE_FREQ}"
  --predict_xstart=True
)

# Training-only arguments.
TRAIN=(
  "--data_dir=${DATA_DIR}"
  --resume_checkpoint=
  --resume_step=0
  "--image_size=${IMAGE_SIZE}"
  --use_fp16=False
  --lr=1e-5
  --save_interval=100000
  --num_workers=24
  "--devices=${GPU}"
)

# Sampling-only arguments. The checkpoint is read from
# ./<RUN_DIR>/checkpoints/<DATASET>_<ITERATIONS>000.pt and results are written
# under ./results/<RUN_DIR>/.
SAMPLE=(
  "--data_dir=${DATA_DIR}"
  "--seed=${SEED}"
  "--image_size=${IMAGE_SIZE}"
  --use_fp16=False
  "--model_path=./${RUN_DIR}/checkpoints/${DATASET}_${ITERATIONS}000.pt"
  "--devices=${GPU}"
  "--output_dir=./results/${RUN_DIR}/${DATASET}_${MODEL}_${ITERATIONS}000/"
  --num_samples=1000
  --use_ddim=False
  "--sampling_steps=${SAMPLING_STEPS}"
  --clip_denoised=True
)
# DATA_MODE is not assigned anywhere in this script, so it previously always
# reached Python as an empty "--data_mode=". Forward it only when it has
# actually been set (e.g. exported by the caller).
# NOTE(review): this assumes generation_sample.py has a usable default for
# --data_mode -- confirm.
if [[ -n "${DATA_MODE-}" ]]; then
  SAMPLE+=("--data_mode=${DATA_MODE}")
fi

# Run the Python script. Paths are relative to the current working directory.
# Any MODE other than 'train' runs sampling here.
if [[ $MODE == 'train' ]]; then
  python scripts/generation_train.py "${TRAIN[@]}" "${COMMON[@]}"
else
  python scripts/generation_sample.py "${SAMPLE[@]}" "${COMMON[@]}"
fi