Skip to content

Commit 2ce99b6

Browse files
committed
Add spectrum tokenizer
1 parent 0b5c6e4 commit 2ce99b6

5 files changed

Lines changed: 717 additions & 0 deletions

File tree

aion/codecs/modules/convnext.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import torch
2+
3+
from aion.codecs.modules.utils import LayerNorm, GRN
4+
5+
6+
class ConvNextBlock1d(torch.nn.Module):
    """ConvNeXtV2 block, adapted to 1D.

    Modified from the original 2D implementation at
    https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
    Note: stochastic depth (drop path) from the reference implementation is
    not implemented here.

    Args:
        dim (int): Number of input channels.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Depthwise conv: kernel 7 with padding 3 keeps the sequence length.
        self.dwconv = torch.nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = torch.nn.Linear(
            dim, 4 * dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = torch.nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = torch.nn.Linear(4 * dim, dim)

    def forward(self, x):
        """Apply the block with a residual connection.

        Args:
            x: Input tensor of shape (B, C, N), with C == dim.

        Returns:
            Tensor of shape (B, C, N).
        """
        y = self.dwconv(x)
        y = y.permute(0, 2, 1)  # (B, C, N) -> (B, N, C)
        y = self.norm(y)
        y = self.pwconv1(y)
        y = self.act(y)
        y = self.grn(y)
        y = self.pwconv2(y)
        y = y.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)

        # Residual connection.
        return x + y
class ConvNextEncoder1d(torch.nn.Module):
    r"""ConvNeXt 1D encoder.

    Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

    Args:
        in_chans: Number of input channels. Default: 2
        depths: Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims: Feature dimension at each stage. Default: (96, 192, 384, 768)
    """

    def __init__(
        self,
        in_chans: int = 2,
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (96, 192, 384, 768),
    ):
        super().__init__()
        assert len(depths) == len(dims), "depths and dims should have the same length"
        num_layers = len(depths)

        # Stem and (num_layers - 1) intermediate downsampling conv layers.
        self.downsample_layers = torch.nn.ModuleList()
        stem = torch.nn.Sequential(
            torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(num_layers - 1):
            downsample_layer = torch.nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # One stage of ConvNext blocks per resolution level.
        self.stages = torch.nn.ModuleList()
        for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
            self.stages.append(stage)

        # Final norm on the coarsest feature map.
        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv/linear layers."""
        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard for bias-free layers so apply() never crashes.
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Encode a (B, in_chans, N) input.

        The sequence length is reduced by 4 at the stem and by 2 at each of
        the remaining len(depths) - 1 downsampling layers (32x total for the
        defaults).
        """
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        return self.norm(x)
class ConvNextDecoder1d(torch.nn.Module):
    r"""ConvNeXt 1D decoder. Essentially a mirrored version of the encoder.

    Args:
        in_chans: Number of input (latent) channels. Default: 768
        depths: Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims: Feature dimension at each stage. Default: (384, 192, 96, 2)
    """

    def __init__(
        self,
        in_chans: int = 768,
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (384, 192, 96, 2),
    ):
        super().__init__()
        assert len(depths) == len(dims), "depths and dims should have the same length"
        num_layers = len(depths)

        self.upsample_layers = torch.nn.ModuleList()

        # Stem: upsample from the latent channels to the first stage width.
        stem = torch.nn.Sequential(
            torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.upsample_layers.append(stem)

        for i in range(num_layers - 1):
            # The final upsampling layer mirrors the encoder's stride-4 stem.
            upsample_layer = torch.nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                torch.nn.ConvTranspose1d(
                    dims[i],
                    dims[i + 1],
                    kernel_size=2 if i < (num_layers - 2) else 4,
                    stride=2 if i < (num_layers - 2) else 4,
                ),
            )
            self.upsample_layers.append(upsample_layer)

        # One stage of ConvNext blocks per resolution level.
        self.stages = torch.nn.ModuleList()
        for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
            self.stages.append(stage)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv/linear layers."""
        # NOTE(review): ConvTranspose1d is not matched by this isinstance
        # check, so the upsampling layers keep PyTorch's default init —
        # confirm this is intended (the encoder's Conv1d layers are matched).
        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard for bias-free layers so apply() never crashes.
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Decode a (B, in_chans, N) latent, upsampling the length by the
        product of the transpose-conv strides (32x for the defaults)."""
        for upsample, stage in zip(self.upsample_layers, self.stages):
            x = upsample(x)
            x = stage(x)
        return x

aion/codecs/modules/spectrum.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
from jaxtyping import Float
3+
4+
5+
def interp1d(
    x: Float[torch.Tensor, " b n"],
    y: Float[torch.Tensor, " b n"],
    xnew: Float[torch.Tensor, " b m"],
    mask_value: float | None = 0.0,
) -> Float[torch.Tensor, " b m"]:
    """Batched linear interpolation of 1-D tensors using torch.searchsorted.

    Assumes that x and xnew are sorted in increasing order.

    Args:
        x: The x-coordinates of the data points, shape [batch, N].
        y: The y-coordinates of the data points, shape [batch, N].
        xnew: The x-coordinates of the interpolated points, shape [batch, M].
        mask_value: The value to use for xnew outside the range of x. If
            None, out-of-range points are linearly extrapolated from the
            nearest segment instead of being masked.

    Returns:
        The y-coordinates of the interpolated points, shape [batch, M].
    """
    # Find the indices where xnew should be inserted in sorted x.
    # Given a point xnew[i], return j where x[j] is the nearest point in x
    # such that x[j] < xnew[i], except if the nearest point has
    # x[j] == xnew[i], in which case return j - 1.
    indices = torch.searchsorted(x, xnew) - 1

    # There are N - 1 segments between points in x, indexed 0 .. N - 2.
    # Points with xnew < min(x) get index -1 and points with xnew > max(x)
    # get index N - 1; clamp into the valid range so the gathers below stay
    # in bounds. The (extrapolated) values obtained for out-of-range points
    # are overwritten below unless mask_value is None.
    indices = torch.clamp(indices, 0, x.shape[1] - 2)

    # Local slope of each segment between consecutive points of x.
    slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:])

    # Interpolate the y-coordinates within the selected segment.
    ynew = torch.gather(y, 1, indices) + (
        xnew - torch.gather(x, 1, indices)
    ) * torch.gather(slopes, 1, indices)

    # Mask out the values that are outside the valid range. Guarded because
    # the signature allows mask_value=None ("don't mask"): assigning None
    # into a tensor would raise a TypeError.
    if mask_value is not None:
        mask = (xnew < x[:, :1]) | (xnew > x[:, -1:])
        ynew[mask] = mask_value

    return ynew
class LatentSpectralGrid(torch.nn.Module):
    """A fixed, uniform wavelength grid used as a shared latent space for
    spectra observed on heterogeneous wavelength grids."""

    def __init__(self, lambda_min: float, resolution: float, num_pixels: int):
        """
        Initialize a latent grid to represent spectra from multiple resolutions.

        Args:
            lambda_min: The minimum wavelength value, in Angstrom.
            resolution: The resolution of the spectra, in Angstrom per pixel.
            num_pixels: The number of pixels in the spectra.
        """
        super().__init__()
        # Buffers (not parameters) so the grid follows the module across
        # devices and is stored in the state dict without being trained.
        self.register_buffer("lambda_min", torch.tensor(lambda_min))
        self.register_buffer("resolution", torch.tensor(resolution))
        self.register_buffer("length", torch.tensor(num_pixels))
        self.register_buffer(
            "_wavelength",
            (torch.arange(0, num_pixels) * resolution + lambda_min).reshape(
                1, num_pixels
            ),
        )

    @property
    def wavelength(self) -> Float[torch.Tensor, " n"]:
        """The latent wavelength grid as a 1-D tensor of length num_pixels."""
        # squeeze(0) rather than squeeze(): a grid with num_pixels == 1 must
        # still come back as a 1-D tensor, not a scalar.
        return self._wavelength.squeeze(0)

    def to_observed(
        self,
        x_latent: Float[torch.Tensor, " b n"],
        wavelength: Float[torch.Tensor, " b m"],
    ) -> Float[torch.Tensor, " b m"]:
        """Transforms the latent representation to the observed wavelength grid.

        Args:
            x_latent: The latent representation, [batch, num_pixels].
            wavelength: The observed wavelength grid, [batch, M].

        Returns:
            The transformed representation on the observed wavelength grid.
        """
        b = x_latent.shape[0]
        return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength)

    def to_latent(
        self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"]
    ) -> Float[torch.Tensor, "b n"]:
        """Transforms the observed representation to the latent wavelength grid.

        Args:
            x_obs: The observed representation, [batch, M].
            wavelength: The observed wavelength grid, [batch, M].

        Returns:
            The transformed representation on the latent wavelength grid,
            [batch, num_pixels].
        """
        b = x_obs.shape[0]
        return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1]))

aion/codecs/modules/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from einops import rearrange
4+
5+
6+
class LayerNorm(torch.nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.

    channels_last expects inputs of shape (batch_size, ..., channels) and
    channels_first expects inputs of shape (batch_size, channels, ...), where
    ... is any number of spatial/sequence dimensions.

    Args:
        normalized_shape: Number of channels to normalize over.
        eps: Epsilon added inside the normalization for numerical stability.
        data_format: Either "channels_last" or "channels_first".
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"Unsupported data_format: {self.data_format!r}"
            )
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Move the channel axis last, normalize, move it back. Equivalent
            # to einops rearrange("b c ... -> b ... c") without the extra
            # dependency.
            x = x.movedim(1, -1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            return x.movedim(-1, 1)
class GRN(torch.nn.Module):
    """GRN (Global Response Normalization) layer.

    Operates on channels-last input (B, N, C): computes a per-channel L2
    response over the sequence axis, divides by its mean across channels,
    and applies a learnable affine gate on top of a residual connection.
    """

    def __init__(self, dim):
        super().__init__()
        # Learnable scale and shift, both initialized to zero so the layer
        # starts as an identity mapping.
        self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # Per-channel global response: L2 norm over the sequence dimension.
        global_response = torch.norm(x, p=2, dim=(1,), keepdim=True)
        # Normalize each channel's response by the mean response across channels.
        divisive_norm = global_response / (
            global_response.mean(dim=-1, keepdim=True) + 1e-6
        )
        gated = self.gamma * (x * divisive_norm) + self.beta
        return gated + x

0 commit comments

Comments
 (0)