Skip to content

Commit 81c50bf

Browse files
committed
added noise scheduler and a notebook to make sure the input runs through the model
1 parent 5e85b00 commit 81c50bf

2 files changed

Lines changed: 144 additions & 3 deletions

File tree

exploratory/test_diffusion.ipynb

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 7,
6+
"id": "7683c4c0",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch\n",
11+
"import torch.nn as nn\n",
12+
"from auto_cast.processors.diffusion import DiffusionProcessor, LogLinearSchedule\n",
13+
"\n",
14+
"class SimpleUNet(nn.Module):\n",
15+
" def __init__(self):\n",
16+
" super().__init__()\n",
17+
" self.net = nn.Sequential(\n",
18+
" nn.Conv2d(3,64,3,padding=1),\n",
19+
" nn.ReLU(),\n",
20+
" nn.Conv2d(64,3,3,padding=1)\n",
21+
" )\n",
22+
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
23+
" return self.net(x)\n",
24+
"# Create components\n",
25+
"denoiser_nn = SimpleUNet()\n",
26+
"loss = nn.MSELoss()\n",
27+
"schedule = LogLinearSchedule(sigma_min=1e-3, sigma_max=1e3)\n"
28+
]
29+
},
30+
{
31+
"cell_type": "code",
32+
"execution_count": 8,
33+
"id": "8a6a6d9d",
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"from auto_cast.processors.diffusion import DiffusionProcessor, LogLinearSchedule\n",
38+
"\n",
39+
"# Create noise schedule\n",
40+
"schedule = LogLinearSchedule(sigma_min=1e-3, sigma_max=1e3)\n",
41+
"\n",
42+
"# Create processor\n",
43+
"processor = DiffusionProcessor(\n",
44+
" denoiser_nn=denoiser_nn,\n",
45+
" loss=loss,\n",
46+
" schedule=schedule\n",
47+
")"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": 10,
53+
"id": "229ccb16",
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"name": "stdout",
58+
"output_type": "stream",
59+
"text": [
60+
"torch.Size([4, 3, 32, 32])\n"
61+
]
62+
}
63+
],
64+
"source": [
65+
"x = torch.randn(4, 3, 32, 32) # batch=4, channels=3, height=32, width=32\n",
66+
"output = processor(x)\n",
67+
"print(output.shape)"
68+
]
69+
},
70+
{
71+
"cell_type": "code",
72+
"execution_count": null,
73+
"id": "f5b4269c",
74+
"metadata": {},
75+
"outputs": [],
76+
"source": []
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "231244a0",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": []
85+
}
86+
],
87+
"metadata": {
88+
"kernelspec": {
89+
"display_name": "auto-cast",
90+
"language": "python",
91+
"name": "python3"
92+
},
93+
"language_info": {
94+
"codemirror_mode": {
95+
"name": "ipython",
96+
"version": 3
97+
},
98+
"file_extension": ".py",
99+
"mimetype": "text/x-python",
100+
"name": "python",
101+
"nbconvert_exporter": "python",
102+
"pygments_lexer": "ipython3",
103+
"version": "3.12.9"
104+
}
105+
},
106+
"nbformat": 4,
107+
"nbformat_minor": 5
108+
}

src/auto_cast/processors/diffusion.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,42 @@
1+
import math
2+
13
import torch
24
import torch.nn as nn
3-
from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor
5+
46
from auto_cast.processors.base import Processor
7+
from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor
58

6-
from azula.noise import Schedule
79

10+
class NoiseSchedule(nn.Module):
    """Base class for noise schedules.

    A schedule is an ``nn.Module`` whose ``forward`` maps time steps ``t`` to
    a pair ``(alpha, sigma)``. This base class defines only the interface;
    calling it directly raises ``NotImplementedError``.
    """

    def forward(self, t: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(alpha, sigma)`` for the given time steps ``t``.

        Raises:
            NotImplementedError: Always; concrete schedules must override
                this method.
        """
        raise NotImplementedError("Subclasses should implement this method.")
18+
19+
class LogLinearSchedule(NoiseSchedule):
    """Log-linear noise schedule.

    ``alpha`` is constant at one, and ``log(sigma)`` is interpolated linearly
    in ``t`` between ``log(sigma_min)`` (at ``t = 0``) and ``log(sigma_max)``
    (at ``t = 1``).
    """

    def __init__(self, sigma_min: float = 0.002, sigma_max: float = 80.0):
        """Initialize the schedule.

        Args:
            sigma_min: Value of sigma at ``t = 0``; must be positive.
            sigma_max: Value of sigma at ``t = 1``; must be positive.

        Raises:
            ValueError: If ``sigma_min`` or ``sigma_max`` is not positive.
        """
        super().__init__()
        # math.log would reject these values anyway, but only with an opaque
        # "math domain error"; fail early with a message naming the argument.
        for name, value in (("sigma_min", sigma_min), ("sigma_max", sigma_max)):
            if value <= 0:
                msg = f"{name} must be positive, got {value!r}."
                raise ValueError(msg)
        # Stored as plain Python floats (not buffers/parameters).
        self.log_sigma_min = math.log(sigma_min)
        self.log_sigma_max = math.log(sigma_max)

    def forward(self, t: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(alpha, sigma)`` for time steps ``t``.

        Args:
            t: Tensor of time steps. Values are not range-checked; ``t = 0``
                yields ``sigma_min`` and ``t = 1`` yields ``sigma_max``.

        Returns:
            A tuple ``(alpha, sigma)``. ``alpha`` is ``torch.ones_like(t)``;
            ``sigma`` is ``exp((1 - t) * log(sigma_min) + t * log(sigma_max))``
            computed elementwise.
        """
        alpha = torch.ones_like(t)
        sigma = torch.exp(self.log_sigma_min * (1 - t) + self.log_sigma_max * t)
        return alpha, sigma
834

935

1036
class DiffusionProcessor(Processor):
1137
"""Diffusion Processor."""
1238

13-
def __init__(self, denoiser_nn, loss, schedule: Schedule, **kwargs):
39+
def __init__(self, denoiser_nn, loss, schedule: NoiseSchedule, **kwargs):
1440
"""Initialize the DiffusionProcessor.
1541
1642
denoiser_nn: The neural network used for denoising.
@@ -26,3 +52,10 @@ def __init__(self, denoiser_nn, loss, schedule: Schedule, **kwargs):
2652
def map(self, x: Tensor) -> Tensor:
2753
"""Map input window of states/times to output window using denoiser."""
2854
return self.denoiser_nn(x)
55+
56+
def forward(self, x: Tensor) -> Tensor:
57+
return self.map(x)
58+
59+
60+
61+

0 commit comments

Comments
 (0)