-
Notifications
You must be signed in to change notification settings - Fork 640
/
Copy pathvae.py
214 lines (166 loc) · 7 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import io
import sys
import os
import requests
import PIL
import warnings
import hashlib
import urllib
import yaml
from pathlib import Path
from tqdm import tqdm
from math import sqrt, log
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
import importlib
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from dalle_pytorch import distributed_utils
# constants
# Local directory where download() caches fetched checkpoints and configs.
CACHE_PATH = os.path.expanduser("~/.cache/dalle")
# OpenAI DALL-E discrete VAE encoder/decoder; OpenAIDiscreteVAE downloads these
# and deserializes them with torch.load via load_model().
OPENAI_VAE_ENCODER_PATH = 'https://cdn.openai.com/dall-e/encoder.pkl'
OPENAI_VAE_DECODER_PATH = 'https://cdn.openai.com/dall-e/decoder.pkl'
# Default VQGAN checkpoint and its YAML config, used by VQGanVAE when no explicit
# paths are given (cached as vqgan.1024.model.ckpt / vqgan.1024.config.yml).
VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1'
VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1'
# helper functions
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    return not (val is None)
def default(val, d):
    """Return *val*, falling back to *d* only when *val* is ``None``.

    Falsy-but-present values such as ``0`` or ``""`` are returned unchanged.
    """
    if val is None:
        return d
    return val
def load_model(path):
    """Deserialize a torch checkpoint from *path*, mapping all tensors onto the CPU.

    NOTE: torch.load unpickles the file contents, so only use this on trusted files.
    """
    cpu = torch.device('cpu')
    with open(path, 'rb') as handle:
        loaded = torch.load(handle, map_location = cpu)
    return loaded
def map_pixels(x, eps = 0.1):
    """Affinely squeeze values from [0, 1] into [eps, 1 - eps]."""
    scale = 1 - 2 * eps
    return x * scale + eps
def unmap_pixels(x, eps = 0.1):
    """Invert map_pixels' affine squeeze and clamp the result into [0, 1]."""
    span = 1 - 2 * eps
    restored = (x - eps) / span
    return torch.clamp(restored, min = 0, max = 1)
def download(url, filename = None, root = CACHE_PATH):
    """Download *url* into *root* (once) and return the local file path.

    The file is streamed into ``tmp.<filename>`` and then moved into place, so
    an existing ``<root>/<filename>`` is treated as a completed download and
    returned immediately.

    In a distributed run, only the local root worker creates *root* and fetches
    the file. Other local workers wait on ``local_barrier()`` when the file is
    not there yet.

    Args:
        url: URL to fetch.
        filename: name of the cached file; defaults to the basename of *url*.
        root: cache directory (defaults to CACHE_PATH).

    Returns:
        The path ``os.path.join(root, filename)``.

    Raises:
        RuntimeError: if the target path exists but is not a regular file.
    """
    # `import urllib` alone does not guarantee the `urllib.request` submodule
    # is loaded; import it explicitly.
    import urllib.request

    if (
        not distributed_utils.is_distributed
        or distributed_utils.backend.is_local_root_worker()
    ):
        os.makedirs(root, exist_ok = True)

    filename = default(filename, os.path.basename(url))
    download_target = os.path.join(root, filename)
    download_target_tmp = os.path.join(root, f'tmp.{filename}')

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if (
        distributed_utils.is_distributed
        and not distributed_utils.backend.is_local_root_worker()
        and not os.path.isfile(download_target)
    ):
        # If the file doesn't exist yet, wait until it's downloaded by the root worker.
        # NOTE(review): if the root finishes before this worker reaches the check,
        # this worker skips the barrier while the root still enters it below;
        # confirm local_barrier semantics cannot deadlock here.
        distributed_utils.backend.local_barrier()

    if os.path.isfile(download_target):
        return download_target

    try:
        with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output:
            # Servers may omit Content-Length; tqdm accepts total=None (unknown size).
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total = total, ncols = 80) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
    except BaseException:
        # Don't leave a partial download behind.
        if os.path.exists(download_target_tmp):
            os.remove(download_target_tmp)
        raise

    # os.replace is atomic and, unlike os.rename, also overwrites on Windows.
    os.replace(download_target_tmp, download_target)

    if (
        distributed_utils.is_distributed
        and distributed_utils.backend.is_local_root_worker()
    ):
        # Release the non-root workers waiting for this file.
        distributed_utils.backend.local_barrier()
    return download_target
# pretrained Discrete VAE from OpenAI
class OpenAIDiscreteVAE(nn.Module):
    """Pretrained discrete VAE released by OpenAI for DALL-E.

    On construction the encoder and decoder are downloaded (or taken from the
    local cache) and loaded onto the CPU. The module is used through
    ``get_codebook_indices`` (images -> token ids) and ``decode``
    (token ids -> images). ``forward`` is intentionally unsupported.
    """

    def __init__(self, image_size=256):
        super().__init__()
        self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
        self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
        # Presumably log2 of the spatial downsampling factor (8x), mirroring how
        # VQGanVAE derives num_layers; TODO confirm.
        self.num_layers = 3
        self.image_size = image_size
        # Codebook size: the number of classes used for one-hot decoding.
        self.num_tokens = 8192

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Encode images into flattened token ids of shape (batch, h * w).

        ``img`` is first squeezed with ``map_pixels``. This assumes pixel values
        in [0, 1]; TODO confirm against callers.
        """
        img = map_pixels(img)
        z_logits = self.enc.blocks(img)
        # Pick the most likely codebook entry per spatial position (channel dim 1).
        z = torch.argmax(z_logits, dim = 1)
        return rearrange(z, 'b h w -> b (h w)')

    def decode(self, img_seq):
        """Decode token ids of shape (batch, n) back into images.

        ``n`` is treated as a square grid with h = w = int(sqrt(n)). The first
        three decoder output channels go through a sigmoid and are then mapped
        back and clamped into [0, 1] with ``unmap_pixels``.
        """
        _, n = img_seq.shape  # also enforces a 2-D (batch, n) input
        img_seq = rearrange(img_seq, 'b (h w) -> b h w', h = int(sqrt(n)))
        z = F.one_hot(img_seq, num_classes = self.num_tokens)
        z = rearrange(z, 'b h w c -> b c h w').float()
        x_stats = self.dec(z).float()
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        return x_rec

    def forward(self, img):
        # `raise NotImplemented` would raise TypeError (NotImplemented is not an
        # exception); NotImplementedError is the correct signal.
        raise NotImplementedError
# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.mod.Name"`` to the named attribute.

    The text before the last dot is imported as a module, and the remainder is
    looked up on it. With ``reload=True`` the module is re-imported first.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config):
    """Build the object described by a ``{"target": ..., "params": {...}}`` mapping.

    Args:
        config: a mapping (e.g. a dict or OmegaConf node). It must contain a
            dotted import path under ``"target"`` and may contain constructor
            keyword arguments under ``"params"``.

    Returns:
        ``target(**params)``.

    Raises:
        KeyError: if ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # A YAML `params:` key left empty loads as None; treat it like an absent key
    # instead of failing on `**None`.
    params = config.get("params")
    if params is None:
        params = {}
    return get_obj_from_str(config["target"])(**params)
class VQGanVAE(nn.Module):
    """Wrapper around a taming-transformers VQGAN used as a discrete image tokenizer.

    Without ``vqgan_model_path``, the default 1024-code checkpoint and its
    config are downloaded into the cache. Otherwise the given checkpoint and
    config are used, and both must be provided. The module is used through
    ``get_codebook_indices`` (images -> token ids) and ``decode``
    (token ids -> images). ``forward`` is intentionally unsupported.
    """

    def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
        super().__init__()
        # log2 is exact for powers of two, unlike log(f) / log(2).
        from math import log2

        if vqgan_model_path is None:
            model_filename = 'vqgan.1024.model.ckpt'
            config_filename = 'vqgan.1024.config.yml'
            config_path = download(VQGAN_VAE_CONFIG_PATH, config_filename)
            model_path = download(VQGAN_VAE_PATH, model_filename)
        else:
            if vqgan_config_path is None:
                raise ValueError(
                    "vqgan_config_path must be provided together with vqgan_model_path")
            model_path = vqgan_model_path
            config_path = vqgan_config_path

        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config["model"])

        state = torch.load(model_path, map_location = 'cpu')['state_dict']
        # strict=False: checkpoint keys not present in the model are ignored.
        model.load_state_dict(state, strict = False)

        print(f"Loaded VQGAN from {model_path} and {config_path}")

        self.model = model

        # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
        f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
        self.num_layers = int(log2(f))
        self.image_size = image_size
        self.num_tokens = config.model.params.n_embed
        self.is_gumbel = isinstance(self.model, GumbelVQ)

        self._register_external_parameters()

    def _codebook_weight(self):
        """Return the codebook embedding matrix; GumbelVQ names it differently."""
        quantize = self.model.quantize
        return quantize.embed.weight if self.is_gumbel else quantize.embedding.weight

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
            not distributed_utils.is_distributed
            or not distributed_utils.using_backend(
                distributed_utils.DeepSpeedBackend)
        ):
            return

        deepspeed = distributed_utils.backend.backend_module
        # decode() reads the codebook directly rather than via a submodule call,
        # so DeepSpeed ZeRO must be told about it.
        deepspeed.zero.register_external_parameter(self, self._codebook_weight())

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Encode images into flattened codebook indices of shape (batch, tokens).

        ``img`` is rescaled with ``2 * img - 1``. This presumably expects pixels
        in [0, 1] and produces [-1, 1]; TODO confirm against callers.
        """
        b = img.shape[0]
        img = (2 * img) - 1
        _, _, [_, _, indices] = self.model.encode(img)
        # GumbelVQ yields (b, h, w) indices; the plain VQModel yields a flat (b * n).
        if self.is_gumbel:
            return rearrange(indices, 'b h w -> b (h w)', b=b)
        return rearrange(indices, '(b n) -> b n', b = b)

    def decode(self, img_seq):
        """Decode token ids of shape (batch, n) into images clamped to [0, 1].

        ``n`` is treated as a square grid with h = w = int(sqrt(n)).
        """
        _, n = img_seq.shape  # also enforces a 2-D (batch, n) input
        one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float()
        z = one_hot_indices @ self._codebook_weight()

        z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n)))
        img = self.model.decode(z)

        # Map the decoder output from [-1, 1] back to [0, 1].
        img = (img.clamp(-1., 1.) + 1) * 0.5
        return img

    def forward(self, img):
        # `raise NotImplemented` would raise TypeError (NotImplemented is not an
        # exception); NotImplementedError is the correct signal.
        raise NotImplementedError