-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDM_train.py
More file actions
42 lines (37 loc) · 1.29 KB
/
DM_train.py
File metadata and controls
42 lines (37 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import logging
import time
from DM.denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
# Log line layout shared by the log file and the console handler.
LOG_FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'


def create_experiment_dir(root='experiments'):
    """Create a timestamped experiment directory and return ``(stamp, path)``.

    ``stamp`` is the local time formatted as ``YYYY-mm-dd-HH-MM-SS`` and
    ``path`` is ``os.path.join(root, stamp)``. The directory is created if it
    does not exist yet.

    NOTE: ``exist_ok=True`` means two runs started within the same second
    share one directory (and append to the same log file).
    """
    stamp = time.strftime('%Y-%m-%d-%H-%M-%S')
    path = os.path.join(root, stamp)
    os.makedirs(path, exist_ok=True)
    return stamp, path


def setup_logging(logfile, level=logging.INFO):
    """Route root-logger output to *logfile* and to stderr using ``LOG_FORMAT``.

    Both handlers share one formatter, so console lines carry the same
    level/location prefix as the file. A bare ``StreamHandler`` would print
    only the message text.

    Handlers are attached directly instead of via ``logging.basicConfig``,
    because ``basicConfig`` silently does nothing when the root logger already
    has handlers.

    Args:
        logfile: Path of the log file. It is opened in append mode as UTF-8.
        level: Level set on the root logger (default ``logging.INFO``).

    Returns:
        The list of handlers that were attached, so callers can remove/close them.
    """
    formatter = logging.Formatter(LOG_FORMAT)
    handlers = [
        logging.FileHandler(logfile, encoding='utf-8'),
        logging.StreamHandler(),  # stderr
    ]
    root = logging.getLogger()
    root.setLevel(level)
    for handler in handlers:
        handler.setFormatter(formatter)
        root.addHandler(handler)
    return handlers


def main():
    """Set up an experiment directory and logging, then run diffusion training.

    Builds ``Unet(channels=4)`` wrapped in ``GaussianDiffusion``. It then trains
    with ``Trainer`` on ``data_npz``, using source modality ``'T1N'`` and target
    modality ``'T2W'``. Results go to ``<experiment dir>/results``.
    """
    now_time, now_time_path = create_experiment_dir()
    setup_logging(os.path.join(now_time_path, '{}.log'.format(now_time)))
    logging.info('Experiment directory: %s', now_time_path)

    # NOTE(review): channels=4 presumably matches the channel count of the
    # arrays stored in data_npz -- confirm against the dataset loader.
    model = Unet(channels=4)
    diffusion = GaussianDiffusion(
        model,
        image_size=512,
        timesteps=1000,
        sampling_timesteps=10,
        objective='pred_v',
    )
    trainer = Trainer(
        diffusion,
        data_path='data_npz',
        source_modality='T1N',
        target_modality='T2W',
        train_batch_size=16,
        gradient_accumulate_every=1,  # gradient accumulation steps
        train_lr=8e-5,
        train_num_steps=200000,
        save_and_sample_every=10000,
        num_samples=16,
        results_folder=os.path.join(now_time_path, 'results'),
    )
    trainer.train()


if __name__ == '__main__':
    main()