#!/usr/bin/env python3
"""Evaluate models."""

import math
from functools import partial

import accelerate
import torch
from fastcore.script import call_parse
from torch.utils import data
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm

import k_diffusion as K

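# Sampler used for both the demo grid and the FID/KID evaluation; any
# k-diffusion sampler with a (model, x, sigmas) signature can be swapped in.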
# sampler = K.sampling.sample_lms
# sampler = K.sampling.sample_euler
sampler = K.sampling.sample_heun


@call_parse
def main(
    config: str,  # the configuration file
    batch_size: int = 256,  # the batch size
    sample_steps: int = 50,  # number of steps to use when sampling
    evaluate_n: int = 2000,  # the number of samples to draw to evaluate
    checkpoint: str = 'model_00050000.pth',  # the path of the checkpoint
    sample_n: int = 64,  # the number of images to sample for demo grids
):
    config = K.config.load_config(open(config))
    model_cfg = config['model']
    dataset_cfg = config['dataset']

    # TODO: allow non-square input sizes
    assert len(model_cfg['input_size']) == 2 and model_cfg['input_size'][0] == model_cfg['input_size'][1]
    size = model_cfg['input_size']

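    # Accelerate picks the device and, when run under `accelerate launch`,
    # can spread sampling and feature extraction across processes.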
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    print(f'Process {accelerator.process_index} using device: {device}', flush=True)

    inner_model = K.config.make_model(config)
    print('Parameters:', K.utils.n_params(inner_model))

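    # Resize and center-crop the reals to the model's input size, then apply
    # the Karras augmentation pipeline; each dataset item becomes a tuple
    # whose element at index 1 is the image tensor used below.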
    tf = transforms.Compose([
        transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(size[0]),
        K.augmentation.KarrasAugmentationPipeline(model_cfg['augment_prob']),
    ])

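    # Build the dataset named in the config; its images supply the "reals"
    # statistics for FID/KID.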
    if dataset_cfg['type'] == 'imagefolder':
        train_set = K.utils.FolderOfImages(dataset_cfg['location'], transform=tf)
    elif dataset_cfg['type'] == 'cifar10':
        train_set = datasets.CIFAR10(dataset_cfg['location'], train=True, download=True, transform=tf)
    elif dataset_cfg['type'] == 'fashion':
        train_set = datasets.FashionMNIST(dataset_cfg['location'], train=True, download=True, transform=tf)
    elif dataset_cfg['type'] == 'mnist':
        train_set = datasets.MNIST(dataset_cfg['location'], train=True, download=True, transform=tf)
    elif dataset_cfg['type'] == 'huggingface':
        from datasets import load_dataset
        train_set = load_dataset(dataset_cfg['location'])
        train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_cfg['image_key']))
        train_set = train_set['train']
    else:
        raise ValueError('Invalid dataset type')

    try:
        print('Number of items in dataset:', len(train_set))
    except TypeError:
        pass  # iterable datasets may not support len()

    image_key = dataset_cfg.get('image_key', 0)

    train_dl = data.DataLoader(train_set, batch_size, shuffle=True, drop_last=True,
                               num_workers=8, persistent_workers=True)

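    # prepare() moves the model to the accelerator's device and wraps the
    # dataloader for the current (possibly distributed) setup.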
    inner_model, train_dl = accelerator.prepare(inner_model, train_dl)

    sigma_min = model_cfg['sigma_min']
    sigma_max = model_cfg['sigma_max']
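
    # Wrap the raw model in the configured denoiser wrapper so the samplers
    # can call it as denoiser(x, sigma).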
    model = K.config.make_denoiser_wrapper(config)(inner_model)

    # Load the EMA weights from the checkpoint
    print(f'Loading EMA weights from {checkpoint}...')
    ckpt = torch.load(checkpoint, map_location='cpu')
    accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model_ema'])

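    # InceptionV3 features of the reals are computed once up front and reused
    # for both FID and KID.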
    extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
    train_iter = iter(train_dl)
    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, evaluate_n, batch_size)
    del train_iter

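    # Save a demo grid of samples alongside the checkpoint for visual
    # inspection.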
    @torch.no_grad()
    @K.utils.eval_mode(model)
    def demo():
        tqdm.write('Sampling...')
        filename = f'{checkpoint}_eval.png'
        x = torch.randn([sample_n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device)
        x_0 = sampler(model, x, sigmas)
        x_0 = x_0[:sample_n]
        # Note the negation: the grid is saved with inverted colors.
        grid = utils.make_grid(-x_0, nrow=math.ceil(sample_n ** 0.5), padding=0)
        K.utils.to_pil_image(grid).save(filename)

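    # Draw evaluate_n samples, extract their Inception features, and compare
    # against the reals to compute FID and KID.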
    @torch.no_grad()
    @K.utils.eval_mode(model)
    def evaluate():
        tqdm.write('Evaluating...')
        sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n):
            # Start from pure noise at sigma_max and integrate down the schedule.
            x = torch.randn([n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max
            return sampler(model, x, sigmas)

        fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, evaluate_n, batch_size)
        fid = K.evaluation.fid(fakes_features, reals_features)
        kid = K.evaluation.kid(fakes_features, reals_features)
        print(f'FID: {fid.item():g}, KID: {kid.item():g}')

    demo()
    evaluate()