-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmyddpm.py
109 lines (87 loc) · 4.21 KB
/
myddpm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import torch
import math
"""
MyScheduler: generate noisy action at any time t given clean action x0. Note, x0 can be image also.
MyDDPM: is the sampler, which uses unet to iteratively sample new clean action.
"""
class MyScheduler:
    """Forward (noising) process of DDPM with a linear beta schedule.

    All per-step constants are precomputed once as 1-D tensors of length ``T``,
    indexed by integer time steps ``t`` in ``[0, T-1]``.
    """

    def __init__(self, T=100, beta_start=1e-4, beta_end=0.02, device='cuda'):
        """Compute all the constants needed for the scheduler (DDPM equations).

        Args:
            T: number of diffusion steps.
            beta_start: first value of the linear beta schedule.
            beta_end: last value of the linear beta schedule.
            device: device on which the constant tensors are stored.
        """
        self.T = T
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, T).to(device)
        self.alphas = 1.0 - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_betas = torch.sqrt(self.betas)
        self.sqrt_alphas = torch.sqrt(self.alphas)
        # sqrt(1 - beta_t); numerically the same as sqrt_alphas.
        self.sqrt_one_minus_betas = torch.sqrt(1.0 - self.betas)
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` and reshape it so it broadcasts against ``x_shape``.

        Args:
            a: 1-D tensor of precomputed per-step values.
            t: time steps -- a 1-D integer tensor of shape ``(B,)``, a 0-d
                tensor, or a Python int. A scalar is treated as a batch of one,
                so the result broadcasts against any batch size.
            x_shape: shape of the tensor the result will be combined with.

        Returns:
            Tensor of shape ``(B, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
        """
        # gather() requires int64 indices on the same device as ``a``; this
        # normalization also makes scalar / CPU / int32 time steps work.
        t = torch.as_tensor(t, dtype=torch.long, device=a.device)
        if t.dim() == 0:
            t = t.reshape(1)
        out = a.gather(-1, t)
        # (batch, 1, 1, ...) -- one singleton per remaining dimension of x.
        return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

    def get_xt(self, x0, t):
        """Compute the noisy sample x_t from x0 at time step(s) t.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I).

        Returns:
            ``(xt, eps)``: the noisy sample and the Gaussian noise that was used.
        """
        eps = torch.randn_like(x0)
        signal_scale = self.extract(self.sqrt_alpha_bars, t, x0.shape)
        noise_scale = self.extract(self.sqrt_one_minus_alpha_bars, t, x0.shape)
        xt = x0 * signal_scale + eps * noise_scale
        return xt, eps
class MyDDPM:
    """DDPM training-noise generator and ancestral sampler.

    ``scheduler`` must provide ``T``, ``get_xt``, ``extract`` and the schedule
    tensors ``sqrt_alphas``, ``sqrt_one_minus_alpha_bars``, ``betas`` and
    ``sqrt_betas`` (as ``MyScheduler`` does). ``noise_predictor_net`` is called
    to predict the noise contained in a noisy sample.
    """

    def __init__(self, scheduler, noise_predictor_net, device='cuda'):
        self.scheduler = scheduler
        self.noise_predictor_net = noise_predictor_net
        self.device = device

    def forward(self, x0, obs_cond=None):
        """Noise ``x0`` at uniformly random time steps and predict that noise.

        Args:
            x0: clean batch; dimension 0 is the batch dimension.
            obs_cond: optional conditioning passed to the network as
                ``global_cond`` (the keyword ``sample_ddpm`` uses). When ``None``
                the network is called as ``net(xt, t)``, as before.

        Returns:
            ``(xt, eps, eps_pred)``: the noisy batch, the noise that was added,
            and the network's prediction of that noise.
        """
        batch_size = x0.shape[0]
        t = torch.randint(low=0, high=self.scheduler.T, size=(batch_size,)).long().to(self.device)
        xt, eps = self.scheduler.get_xt(x0, t)
        if obs_cond is None:
            eps_pred = self.noise_predictor_net(xt, t)
        else:
            eps_pred = self.noise_predictor_net(xt, t, global_cond=obs_cond)
        return xt, eps, eps_pred

    def x_t_minus_1_from_x_t(self, t, x_t, eps_theta):
        """Noise-free part of one reverse step (Algorithm 2 of the DDPM paper).

        Returns 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta);
        the caller adds the sigma_t * z noise term.
        """
        sqrt_alpha = self.scheduler.extract(self.scheduler.sqrt_alphas, t, x_t.shape)
        sqrt_one_minus_alpha_bar = self.scheduler.extract(self.scheduler.sqrt_one_minus_alpha_bars, t, x_t.shape)
        beta = self.scheduler.extract(self.scheduler.betas, t, x_t.shape)
        return (1 / sqrt_alpha) * (x_t - (beta / sqrt_one_minus_alpha_bar) * eps_theta)

    def sample_ddpm(self, nsamples, sample_shape, obs_cond):
        """Sample with the DDPM method of Ho et al. (Algorithm 2).

        Args:
            nsamples: number of samples to draw (batch size).
            sample_shape: shape of one sample, excluding the batch dimension.
            obs_cond: conditioning passed to the network as ``global_cond``.

        Returns:
            ``(x, xts)``: the final samples and the list of intermediate states,
            starting with the initial noise (``T + 1`` entries).
        """
        with torch.no_grad():
            x = torch.randn(size=(nsamples, *sample_shape)).to(self.device)  # start from random noise
            xts = [x]
            # Time-step indices are 0..T-1 (training draws t from that range),
            # so the loop must include index 0, the final denoising step.
            for it in range(self.scheduler.T - 1, -1, -1):
                t = torch.full((nsamples,), it, dtype=torch.long, device=self.device)
                eps_theta = self.noise_predictor_net(
                    sample=x,
                    timestep=t,
                    global_cond=obs_cond,
                )
                # See DDPM paper between equations 11 and 12.
                x = self.x_t_minus_1_from_x_t(t, x, eps_theta)
                if it > 0:  # no noise is added on the final step (index 0)
                    z = torch.randn(size=(nsamples, *sample_shape)).to(self.device)
                    # Fixed variance sigma_t^2 = beta_t.
                    sqrt_beta = self.scheduler.extract(self.scheduler.sqrt_betas, t, x.shape)
                    x = x + sqrt_beta * z
                xts.append(x)
            return x, xts