Skip to content

Commit 814818e

Browse files
committed
add forward diffusion process
1 parent 81c50bf commit 814818e

2 files changed

Lines changed: 65 additions & 9 deletions

File tree

exploratory/test_diffusion.ipynb

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 7,
5+
"execution_count": 1,
66
"id": "7683c4c0",
77
"metadata": {},
88
"outputs": [],
@@ -27,9 +27,17 @@
2727
"schedule = LogLinearSchedule(sigma_min=1e-3, sigma_max=1e3)\n"
2828
]
2929
},
30+
{
31+
"cell_type": "markdown",
32+
"id": "f08c9060",
33+
"metadata": {},
34+
"source": [
35+
"# test model initiation "
36+
]
37+
},
3038
{
3139
"cell_type": "code",
32-
"execution_count": 8,
40+
"execution_count": 2,
3341
"id": "8a6a6d9d",
3442
"metadata": {},
3543
"outputs": [],
@@ -47,9 +55,17 @@
4755
")"
4856
]
4957
},
58+
{
59+
"cell_type": "markdown",
60+
"id": "f5a9df7e",
61+
"metadata": {},
62+
"source": [
63+
"# test forward pass "
64+
]
65+
},
5066
{
5167
"cell_type": "code",
52-
"execution_count": 10,
68+
"execution_count": 3,
5369
"id": "229ccb16",
5470
"metadata": {},
5571
"outputs": [
@@ -67,13 +83,35 @@
6783
"print(output.shape)"
6884
]
6985
},
86+
{
87+
"cell_type": "markdown",
88+
"id": "7e7ae519",
89+
"metadata": {},
90+
"source": [
91+
"# test sampler"
92+
]
93+
},
7094
{
7195
"cell_type": "code",
72-
"execution_count": null,
96+
"execution_count": 4,
7397
"id": "f5b4269c",
7498
"metadata": {},
75-
"outputs": [],
76-
"source": []
99+
"outputs": [
100+
{
101+
"name": "stdout",
102+
"output_type": "stream",
103+
"text": [
104+
"torch.Size([4, 3, 32, 32])\n"
105+
]
106+
}
107+
],
108+
"source": [
109+
"x_0 = torch.randn(4, 3, 32, 32) # (B, C, H, W)\n",
110+
"t = torch.rand(4) # Random times\n",
111+
"\n",
112+
"x_t = processor.q_sample(x_0, t)\n",
113+
"print(x_t.shape)"
114+
]
77115
},
78116
{
79117
"cell_type": "code",

src/auto_cast/processors/diffusion.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,24 @@ def map(self, x: Tensor) -> Tensor:
5656
def forward(self, x: Tensor) -> Tensor:
5757
return self.map(x)
5858

59-
60-
61-
59+
def q_sample(self, x_0: Tensor, t:Tensor) -> Tensor:
60+
"""Forward diffusion q(x_t | x_0).
61+
62+
Sample from q(x_t|x_0) = N(alpha_t * x_0, Sigma_t^2*I)
63+
where alpha_t and sigma_t are obtained from the noise schedule.
64+
65+
Args:
66+
x_0: clean data (B, C, H, W)
67+
t: time (B,)
68+
69+
Returns
70+
-------
71+
x_t: noised data at t (B, C, H, W)
72+
"""
73+
alpha_t, sigma_t = self.schedule(t)
74+
# Reshape (B,) to (B, 1, 1, 1) for broadcasting with (B, C, H, W)
75+
alpha_t = alpha_t.view(-1, 1, 1, 1)
76+
sigma_t = sigma_t.view(-1, 1, 1, 1)
77+
noise = torch.randn_like(x_0)
78+
x_t = alpha_t * x_0 + sigma_t * noise
79+
return x_t

0 commit comments

Comments
 (0)