Commit b2ef15a

sgreenbury, radka-j, marjanfamili, cisprague, and ContiPaolo committed

Refactoring diffusion

Co-authored-by: Radka Jersakova <r.jersakova@gmail.com>
Co-authored-by: Marjan Famili <marjanfamili@users.noreply.github.com>
Co-authored-by: Christopher Iliffe Sprague <cisprague@users.noreply.github.com>
Co-authored-by: Paolo Conti <ContiPaolo@users.noreply.github.com>
1 parent 02776e7 commit b2ef15a

1 file changed: src/auto_cast/processors/diffusion.py

Lines changed: 9 additions & 21 deletions
@@ -31,6 +31,7 @@ def __init__(
         n_steps_output: int = 4,
         n_channels_out: int = 1,
         sampler_steps: int = 50,
+        sampler: str = "euler",
     ):
         super().__init__()
         self.teacher_forcing_ratio = teacher_forcing_ratio
@@ -40,6 +41,7 @@ def __init__(
         self.n_steps_output = n_steps_output
         self.n_channels_out = n_channels_out
         self.sampler_steps = sampler_steps
+        self.sampler = sampler

         # Create Azula denoiser with chosen preconditioning
         if denoiser_type == "simple":
@@ -54,13 +56,15 @@ def __init__(

     def map(self, x: Tensor) -> Tensor:
         """Map input window of states/times to output window using denoiser."""
-        # if we start from zero at every autoregressive step, the model is asked to
-        # denoise using t=0, which is a point it has never been trained on.
-        # self.inference_t = 1e-5
         dtype = x.dtype
         device = x.device
         sampler = self._get_sampler(self.sampler_steps, dtype=dtype, device=device)
         B, _, W, H, _ = x.shape
+
+        # if we start from zero at every autoregressive step, the model is asked to
+        # denoise using t=0, which is a point it has never been trained on.
+        # self.inference_t = 1e-5
+        # using azula sampler init to create noise at t=1
         x_1 = sampler.init(
             (B, self.n_steps_output, W, H, self.n_channels_out),
             dtype=dtype,
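The comment moved in this hunk captures a real pitfall: a denoiser trained on noise levels t in (0, 1] has never seen t=0, so the reverse process should start from the t=1 prior rather than from zero. A minimal illustrative sketch of the idea, assuming a simple rectified-flow style schedule (not necessarily the schedule this repository uses):

import torch

# Illustrative only: a linear (rectified-flow style) schedule, an assumption
# for the sake of the example rather than code from this repository.
def linear_schedule(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # x_t = alpha_t * x_0 + sigma_t * noise, with alpha_t = 1 - t, sigma_t = t
    return 1 - t, t

# At t=1: alpha=0, sigma=1 -> pure noise, a valid place to start sampling.
# At t=0: alpha=1, sigma=0 -> clean data, a noise level never seen in training.
for t in (torch.tensor(1.0), torch.tensor(0.0)):
    alpha, sigma = linear_schedule(t)
    print(f"t={t.item():.0f}: alpha={alpha.item():.2f}, sigma={sigma.item():.2f}")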
@@ -69,23 +73,6 @@ def map(self, x: Tensor) -> Tensor:
         return sampler(x_1, cond=x)

     def forward(self, x: Tensor) -> Tensor:
-        # # Training mode: sample random time and denoise
-        # if self.train:
-        #     B, _, W, H, _ = x.shape
-
-        #     # Sample a random time
-        #     t = torch.rand(B, device=x.device) * 0.999 + 0.001  # Avoid t=0 or t=1
-
-        #     # Create noisy input
-        #     x_noisy = torch.randn(
-        #         B, self.n_steps_output, W, H, self.n_channels_out, device=x.device
-        #     )
-
-        #     # Denoise (this preserves gradients)
-        #     posterior = self.denoiser(x_noisy, t, cond=x)
-        #     return posterior.mean
-
-        # Evaluation mode: use the map function
         return self.map(x)

     def _denoise(self, x: Tensor, t: Tensor, cond: Tensor) -> Tensor:
@@ -107,6 +94,7 @@ def loss(self, batch: EncodedBatch) -> Tensor:
         # Cannot use Azula's built-in weighted loss since lightning calls forward
         loss = self.denoiser.loss(x_0, t=t, cond=x_cond)

+        # TODO: consider an API for looking at alternative losses
         # # Compute weighted loss
         # alpha_t, sigma_t = self.schedule(t)
         # alpha_t = alpha_t.view(-1, 1, 1, 1, 1)  # (B, 1, 1, 1, 1)
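The commented-out block above only hints at the alternative weighted loss the new TODO refers to. A self-contained sketch of that pattern, under stated assumptions: the `schedule` and `denoiser` call signatures, the SNR weighting, and the (B, T, W, H, C) layout are inferred from the surrounding diff, not a confirmed repository API.

import torch

# Hypothetical sketch of a manually weighted denoising loss; every name here
# (schedule, denoiser, the SNR weight) is an assumption for illustration.
def weighted_denoising_loss(denoiser, schedule, x_0, t, cond):
    alpha_t, sigma_t = schedule(t)
    alpha_t = alpha_t.view(-1, 1, 1, 1, 1)  # (B, 1, 1, 1, 1) for broadcasting
    sigma_t = sigma_t.view(-1, 1, 1, 1, 1)
    noise = torch.randn_like(x_0)
    x_t = alpha_t * x_0 + sigma_t * noise  # forward-noised sample
    x_0_hat = denoiser(x_t, t, cond=cond)  # predicted clean signal
    weight = (alpha_t / sigma_t) ** 2      # SNR weighting per noise level
    return (weight * (x_0_hat - x_0) ** 2).mean()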
@@ -124,11 +112,11 @@ def loss(self, batch: EncodedBatch) -> Tensor:
     def _get_sampler(
         self,
         num_steps: int = 100,
-        sampler: str = "euler",
         eta: float = 0.0,
         silent: bool = True,
         **sampler_kwargs,
     ) -> Sampler:
+        sampler = self.sampler
        # Create appropriate Azula sampler
        if sampler == "euler":
            azula_sampler = EulerSampler(
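Taken together, the change moves sampler selection from a per-call `_get_sampler` argument to a constructor field, so the choice is made once as configuration. A sketch of how a caller might now construct the processor; the class name `DiffusionProcessor` and the omitted arguments are assumptions, not shown in this diff:

# Hypothetical usage based on the signature fragments visible above.
processor = DiffusionProcessor(
    n_steps_output=4,
    n_channels_out=1,
    sampler_steps=50,
    sampler="euler",  # new in this commit: chosen once, at construction
)
# _get_sampler() now reads self.sampler rather than taking a `sampler`
# argument, so map() and any other call sites share the configured choice.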
