Skip to content

Commit 985ea8c

Browse files
authored
Merge pull request #245 from NathanSoeding/perspective
Perspective module from Wang et al 2025 implementation
2 parents 795c028 + 3f56f53 commit 985ea8c

File tree

3 files changed

+316
-1
lines changed

3 files changed

+316
-1
lines changed

neuralpredictors/layers/encoders/firing_rate.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(
1313
core,
1414
readout,
1515
*,
16+
perspective=None,
1617
shifter=None,
1718
modulator=None,
1819
elu_offset=0.0,
@@ -34,6 +35,7 @@ def __init__(
3435
super().__init__()
3536
self.core = core
3637
self.readout = readout
38+
self.perspective = perspective
3739
self.shifter = shifter
3840
self.modulator = modulator
3941
self.offset = elu_offset
@@ -63,7 +65,18 @@ def forward(
6365
detach_core=False,
6466
**kwargs
6567
):
66-
x = self.core(inputs)
68+
x = inputs
69+
70+
if self.perspective:
71+
if self.shifter:
72+
raise ValueError("both perspective and shifter cannot be present together, only one should be chosen")
73+
74+
if pupil_center is None:
75+
raise ValueError("pupil_center is not given")
76+
77+
x = self.perspective[data_key](x, pupil_center)
78+
79+
x = self.core(x)
6780
if detach_core:
6881
x = x.detach()
6982

neuralpredictors/layers/perspective/__init__.py

Whitespace-only changes.
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
'''
2+
code is adjusted from the Wang et al 2025 "Foundation model of neural activity predicts response to new stimulus types" implementation
3+
Specifically
4+
* https://github.com/cajal/fnn/blob/main/fnn/model/pixels.py
5+
* https://github.com/cajal/fnn/blob/main/fnn/model/perspectives.py
6+
'''
7+
import torch
8+
from torch import nn
9+
10+
def angles_to_rmat3d(angles):
11+
"""
12+
Convert batches of Euler angles (x, y, z) to 3D rotation matrices.
13+
14+
Args:
15+
angles (torch.Tensor): Tensor of shape (N, 3), angles in radians.
16+
17+
Returns:
18+
torch.Tensor: Tensor of shape (N, 3, 3), rotation matrices (Rz * Ry * Rx).
19+
"""
20+
x, y, z = torch.unbind(angles, axis=-1)
21+
N = len(x)
22+
23+
A = torch.eye(3, device=x.device).repeat(N, 1, 1)
24+
B = torch.eye(3, device=x.device).repeat(N, 1, 1)
25+
C = torch.eye(3, device=x.device).repeat(N, 1, 1)
26+
27+
cos_z = torch.cos(z)
28+
sin_z = torch.sin(z)
29+
30+
A[:, 0, 0] = cos_z
31+
A[:, 0, 1] = -sin_z
32+
A[:, 1, 0] = sin_z
33+
A[:, 1, 1] = cos_z
34+
35+
cos_y = torch.cos(y)
36+
sin_y = torch.sin(y)
37+
38+
B[:, 0, 0] = cos_y
39+
B[:, 0, 2] = sin_y
40+
B[:, 2, 0] = -sin_y
41+
B[:, 2, 2] = cos_y
42+
43+
cos_x = torch.cos(x)
44+
sin_x = torch.sin(x)
45+
46+
C[:, 1, 1] = cos_x
47+
C[:, 1, 2] = -sin_x
48+
C[:, 2, 1] = sin_x
49+
C[:, 2, 2] = cos_x
50+
51+
return A @ B @ C
52+
53+
54+
class PixelTransform(nn.Module):
55+
"""
56+
Nonlinear pixel intensity transform with learnable power, scale, and offset.
57+
"""
58+
def __init__(self, max_power=1, init_scale=1, init_offset=0, eps=1e-5):
59+
super().__init__()
60+
61+
self.max_power = max_power
62+
self.eps = eps
63+
64+
self.logit = nn.Parameter(torch.zeros(1))
65+
self.scale = nn.Parameter(torch.full([1], init_scale, dtype=torch.float32))
66+
self.offset = nn.Parameter(torch.full([1], init_offset, dtype=torch.float32))
67+
68+
@property
69+
def power(self):
70+
return self.logit.sigmoid() * self.max_power
71+
72+
def forward(self, pixels):
73+
return pixels.add(self.eps).pow(self.power).mul(self.scale).add(self.offset)
74+
75+
76+
class Retina(nn.Module):
77+
"""
78+
Models a retina that maps pupil centers to 3D rays via an MLP.
79+
80+
Args:
81+
degree (float): Field of view in degrees.
82+
height, width (int): Retina grid resolution.
83+
dim_in, dim_out (int): Input/output dimensions for MLP.
84+
mlp_features (int): Hidden feature size.
85+
mlp_layers (int): Number of MLP layers.
86+
max_angle (float): Maximum rotation angle in degrees.
87+
"""
88+
def __init__(
89+
self,
90+
degree=50,
91+
height=36,
92+
width=64,
93+
dim_in=2,
94+
dim_out=2,
95+
mlp_features=16,
96+
mlp_layers=3,
97+
max_angle=30,
98+
):
99+
super().__init__()
100+
101+
grid = self.create_grid(height, width, degree)
102+
self.register_buffer("grid", grid)
103+
self.max_angle = max_angle / torch.pi * 180
104+
105+
layers = []
106+
107+
features = [dim_in] + [mlp_features] * mlp_layers + [dim_out]
108+
non_linearities = [nn.GELU()] * mlp_layers + [None]
109+
110+
for in_features, out_features, nonlinear in zip(features[:-1], features[1:], non_linearities):
111+
linear = nn.Linear(in_features, out_features)
112+
linear = nn.utils.parametrizations.weight_norm(linear)
113+
nn.init.zeros_(linear.bias)
114+
layers.append(linear)
115+
116+
if nonlinear is not None:
117+
layers.append(nonlinear)
118+
119+
self.mlp = nn.Sequential(*layers)
120+
121+
def create_grid(self, height, width, degree):
122+
# Create isotropic grid of retina
123+
x_axis = torch.linspace(-1, 1, width)
124+
y_axis = torch.linspace(-1, 1, height) * height / width
125+
scale = (width - 1) / width
126+
127+
x, y = torch.meshgrid(
128+
x_axis * scale,
129+
y_axis * scale,
130+
indexing="xy",
131+
)
132+
133+
# Convert to grid of 3D rays corresponding to retina pixels
134+
radians = degree / 180 * torch.pi
135+
136+
r = torch.sqrt(x.pow(2) + y.pow(2)).mul(radians).clip(0, torch.pi / 2)
137+
cos_r = torch.cos(r)
138+
sin_r = torch.sin(r)
139+
140+
theta = torch.atan2(y, x)
141+
cos_theta = torch.cos(theta)
142+
sin_theta = torch.sin(theta)
143+
144+
ray_grid = [
145+
sin_r * cos_theta,
146+
sin_r * sin_theta,
147+
cos_r,
148+
]
149+
150+
return torch.stack(ray_grid, dim=-1)
151+
152+
def rotate_retina(self, rmat):
153+
return torch.einsum("N C D , H W D -> N H W C", rmat, self.grid)
154+
155+
# Take pupil center to return rotated grid of retina rays
156+
def rays(self, pupil_center):
157+
angles = self.mlp(pupil_center)
158+
angles = torch.clip(angles, -self.max_angle, self.max_angle)
159+
160+
pad_zeros = torch.zeros((angles.shape[0], 1), device=angles.device)
161+
angles = torch.concat([angles, pad_zeros], axis=1)
162+
163+
rmat = angles_to_rmat3d(angles)
164+
rays = self.rotate_retina(rmat)
165+
166+
return rays
167+
168+
169+
class Monitor(nn.Module):
170+
"""
171+
Models a monitor in 3D space with optimizable position and orientation.
172+
173+
Provides projection of retinal rays onto the monitor plane and sampling of images.
174+
"""
175+
def __init__(
176+
self,
177+
init_center_x=0,
178+
init_center_y=0,
179+
init_center_z=0.5,
180+
init_center_std=0.05,
181+
init_angle_x=0,
182+
init_angle_y=0,
183+
init_angle_z=0,
184+
init_angle_std=0.05,
185+
eps=1e-5,
186+
):
187+
super().__init__()
188+
189+
center = [
190+
init_center_x,
191+
init_center_y,
192+
init_center_z,
193+
]
194+
self.center = nn.Parameter(torch.tensor(center, dtype=torch.float32))
195+
196+
angle = [
197+
init_angle_x,
198+
init_angle_y,
199+
init_angle_z,
200+
]
201+
self.angle = nn.Parameter(torch.tensor(angle, dtype=torch.float32))
202+
203+
self.center_std = nn.Parameter(torch.tensor(init_center_std, dtype=torch.float32))
204+
self.angle_std = nn.Parameter(torch.tensor(init_angle_std, dtype=torch.float32))
205+
self.eps = float(eps)
206+
207+
# Optimize position of monitor
208+
def position(self, batch_size):
209+
center = self.center.repeat(batch_size, 1)
210+
angle = self.angle.repeat(batch_size, 1)
211+
212+
if self.training:
213+
center = center + torch.randn(batch_size, 3, device=center.device) * self.center_std
214+
angle = angle + torch.randn(batch_size, 3, device=angle.device) * self.angle_std
215+
216+
x, y, z = angles_to_rmat3d(angle).unbind(2)
217+
218+
return center, x, y, z
219+
220+
# Project rays onto monitor coordinates
221+
def project(self, rays):
222+
center, x, y, z = self.position(len(rays))
223+
224+
a = torch.einsum("N D , N D -> N", z, center)[:, None, None]
225+
b = torch.einsum("N D , N H W D -> N H W", z, rays).clip(self.eps)
226+
227+
c = torch.einsum("N H W , N H W D -> N H W D", a / b, rays)
228+
d = c - center[:, None, None, :]
229+
230+
proj = [
231+
torch.einsum("N H W D , N D -> N H W", d, x),
232+
torch.einsum("N H W D , N D -> N H W", d, y),
233+
]
234+
return torch.stack(proj, dim=3)
235+
236+
# Samples values in img at positions given by grid
237+
def sample_screen(self, img, grid):
238+
_, _, H_in, W_in = img.shape
239+
grid_x, grid_y = grid.unbind(dim=3)
240+
241+
grid_y = grid_y * W_in / H_in
242+
scale = W_in / (W_in - 1)
243+
244+
_, H_out, W_out, _ = grid.shape
245+
grid = [
246+
grid_x * scale * (W_out - 1) / W_out,
247+
grid_y * scale * (H_out - 1) / H_out,
248+
]
249+
250+
return nn.functional.grid_sample(
251+
input=img,
252+
grid=torch.stack(grid, dim=3),
253+
mode="bilinear",
254+
align_corners=False,
255+
)
256+
257+
258+
# Combines Retina and Monitor
259+
260+
261+
class SinglePerspective(nn.Module):
262+
"""
263+
Combines Retina, Monitor, and PixelTransform to generate a single visual perspective.
264+
"""
265+
def __init__(self, retina, monitor, pixel_transform, static_power=1.7):
266+
super().__init__()
267+
268+
self.retina = retina
269+
self.monitor = monitor
270+
self.pixel_transform = pixel_transform
271+
self.static_power = static_power
272+
273+
def forward(self, img, pupil_center):
274+
rays = self.retina.rays(pupil_center)
275+
grid = self.monitor.project(rays)
276+
277+
pixels = img
278+
279+
img = (pixels[:, None, 0, :, :] / 255.0).pow(self.static_power)
280+
behaviour = pixels[:, 1:, :, :]
281+
pixels = torch.concat([img, behaviour], axis=1)
282+
283+
pixels = self.monitor.sample_screen(pixels, grid)
284+
285+
img = self.pixel_transform(pixels[:, None, 0, :, :])
286+
behaviour = pixels[:, 1:, :, :]
287+
pixels = torch.concat([img, behaviour], axis=1)
288+
289+
return pixels
290+
291+
292+
class Perspective(nn.ModuleDict):
293+
"""
294+
Container for multiple SinglePerspective modules keyed by dataset identifiers.
295+
"""
296+
def __init__(self, data_keys, retina_degree=75, mlp_features=16, mlp_layers=3):
297+
super().__init__()
298+
299+
for k in data_keys:
300+
retina = Retina(degree=retina_degree, mlp_features=mlp_features, mlp_layers=mlp_layers)
301+
monitor = Monitor()
302+
self.add_module(k, SinglePerspective(retina, monitor, PixelTransform()))

0 commit comments

Comments
 (0)