Skip to content

Commit 48c3f87

Browse files
author
nathanpaul.soeding
committed
Added perspective module
1 parent 2b42005 commit 48c3f87

File tree

3 files changed

+282
-1
lines changed

3 files changed

+282
-1
lines changed

neuralpredictors/layers/encoders/firing_rate.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(
1313
core,
1414
readout,
1515
*,
16+
perspective=None,
1617
shifter=None,
1718
modulator=None,
1819
elu_offset=0.0,
@@ -34,6 +35,7 @@ def __init__(
3435
super().__init__()
3536
self.core = core
3637
self.readout = readout
38+
self.perspective = perspective
3739
self.shifter = shifter
3840
self.modulator = modulator
3941
self.offset = elu_offset
@@ -63,7 +65,14 @@ def forward(
6365
detach_core=False,
6466
**kwargs
6567
):
66-
x = self.core(inputs)
68+
x = inputs
69+
70+
if self.perspective:
71+
if pupil_center is None:
72+
raise ValueError("pupil_center is not given")
73+
x = self.perspective[data_key](x, pupil_center)
74+
75+
x = self.core(x)
6776
if detach_core:
6877
x = x.detach()
6978

neuralpredictors/layers/perspective/__init__.py

Whitespace-only changes.
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
def angles_to_rmat3d(angles):
6+
x, y, z = torch.unbind(angles, axis=-1)
7+
N = len(x)
8+
9+
A = torch.eye(3, device=x.device).repeat(N, 1, 1)
10+
B = torch.eye(3, device=x.device).repeat(N, 1, 1)
11+
C = torch.eye(3, device=x.device).repeat(N, 1, 1)
12+
13+
cos_z = torch.cos(z)
14+
sin_z = torch.sin(z)
15+
16+
A[:, 0, 0] = cos_z
17+
A[:, 0, 1] = -sin_z
18+
A[:, 1, 0] = sin_z
19+
A[:, 1, 1] = cos_z
20+
21+
cos_y = torch.cos(y)
22+
sin_y = torch.sin(y)
23+
24+
B[:, 0, 0] = cos_y
25+
B[:, 0, 2] = sin_y
26+
B[:, 2, 0] = -sin_y
27+
B[:, 2, 2] = cos_y
28+
29+
cos_x = torch.cos(x)
30+
sin_x = torch.sin(x)
31+
32+
C[:, 1, 1] = cos_x
33+
C[:, 1, 2] = -sin_x
34+
C[:, 2, 1] = sin_x
35+
C[:, 2, 2] = cos_x
36+
37+
return A @ B @ C
38+
39+
40+
class PixelTransform(nn.Module):
41+
def __init__(self, max_power=1, init_scale=1, init_offset=0, eps=1e-5):
42+
super().__init__()
43+
44+
self.max_power = max_power
45+
self.eps = eps
46+
47+
self.logit = nn.Parameter(torch.zeros(1))
48+
self.scale = nn.Parameter(torch.full([1], init_scale, dtype=torch.float32))
49+
self.offset = nn.Parameter(torch.full([1], init_offset, dtype=torch.float32))
50+
51+
@property
52+
def power(self):
53+
return self.logit.sigmoid() * self.max_power
54+
55+
def forward(self, pixels):
56+
return pixels.add(self.eps).pow(self.power).mul(self.scale).add(self.offset)
57+
58+
59+
class Scale(nn.Module):
60+
def __init__(self, gamma):
61+
super().__init__()
62+
self.gamma = gamma
63+
64+
def forward(self, x):
65+
return x * self.gamma
66+
67+
68+
class Retina(nn.Module):
69+
def __init__(
70+
self,
71+
degree=75,
72+
height=36,
73+
width=64,
74+
dim_in=2,
75+
dim_out=2,
76+
mlp_features=16,
77+
mlp_layers=3,
78+
max_angle=30,
79+
):
80+
super().__init__()
81+
82+
grid = self.create_grid(height, width, degree)
83+
self.register_buffer("grid", grid)
84+
self.max_angle = max_angle / torch.pi * 180
85+
86+
layers = []
87+
88+
features = [dim_in] + [mlp_features] * mlp_layers + [dim_out]
89+
non_linearities = [nn.GELU()] * mlp_layers + [None]
90+
91+
for in_features, out_features, nonlinear in zip(features[:-1], features[1:], non_linearities):
92+
linear = nn.Linear(in_features, out_features)
93+
linear = nn.utils.parametrizations.weight_norm(linear)
94+
nn.init.zeros_(linear.bias)
95+
layers.append(linear)
96+
97+
if nonlinear is not None:
98+
layers.append(nonlinear)
99+
100+
self.mlp = nn.Sequential(*layers)
101+
102+
def create_grid(self, height, width, degree):
103+
# Create isotropic grid of retina
104+
x_axis = torch.linspace(-1, 1, width)
105+
y_axis = torch.linspace(-1, 1, height) * height / width
106+
scale = (width - 1) / width
107+
108+
x, y = torch.meshgrid(
109+
x_axis * scale,
110+
y_axis * scale,
111+
indexing="xy",
112+
)
113+
114+
# Convert to grid of 3D rays corresponding to retina pixels
115+
radians = degree / 180 * torch.pi
116+
117+
r = torch.sqrt(x.pow(2) + y.pow(2)).mul(radians).clip(0, torch.pi / 2)
118+
cos_r = torch.cos(r)
119+
sin_r = torch.sin(r)
120+
121+
theta = torch.atan2(y, x)
122+
cos_theta = torch.cos(theta)
123+
sin_theta = torch.sin(theta)
124+
125+
ray_grid = [
126+
sin_r * cos_theta,
127+
sin_r * sin_theta,
128+
cos_r,
129+
]
130+
131+
return torch.stack(ray_grid, dim=-1)
132+
133+
def rotate_retina(self, rmat):
134+
return torch.einsum("N C D , H W D -> N H W C", rmat, self.grid)
135+
136+
# Take pupil center to return rotated grid of retina rays
137+
def rays(self, pupil_center):
138+
angles = self.mlp(pupil_center)
139+
angles = torch.clip(angles, -self.max_angle, self.max_angle)
140+
141+
pad_zeros = torch.zeros((angles.shape[0], 1), device=angles.device)
142+
angles = torch.concat([angles, pad_zeros], axis=1)
143+
144+
rmat = angles_to_rmat3d(angles)
145+
rays = self.rotate_retina(rmat)
146+
147+
return rays
148+
149+
150+
class Monitor(nn.Module):
151+
def __init__(
152+
self,
153+
init_center_x=0,
154+
init_center_y=0,
155+
init_center_z=0.5,
156+
init_center_std=0.05,
157+
init_angle_x=0,
158+
init_angle_y=0,
159+
init_angle_z=0,
160+
init_angle_std=0.05,
161+
eps=1e-5,
162+
):
163+
super().__init__()
164+
165+
center = [
166+
init_center_x,
167+
init_center_y,
168+
init_center_z,
169+
]
170+
self.center = nn.Parameter(torch.tensor(center, dtype=torch.float32))
171+
172+
angle = [
173+
init_angle_x,
174+
init_angle_y,
175+
init_angle_z,
176+
]
177+
self.angle = nn.Parameter(torch.tensor(angle, dtype=torch.float32))
178+
179+
self.center_std = nn.Parameter(torch.tensor(init_center_std, dtype=torch.float32))
180+
self.angle_std = nn.Parameter(torch.tensor(init_angle_std, dtype=torch.float32))
181+
self.eps = float(eps)
182+
183+
# Optimize position of monitor
184+
def position(self, batch_size):
185+
center = self.center.repeat(batch_size, 1)
186+
angle = self.angle.repeat(batch_size, 1)
187+
188+
if self.training:
189+
center = center + torch.randn(batch_size, 3, device=center.device) * self.center_std
190+
angle = angle + torch.randn(batch_size, 3, device=angle.device) * self.angle_std
191+
192+
x, y, z = angles_to_rmat3d(angle).unbind(2)
193+
194+
return center, x, y, z
195+
196+
# Project rays onto monitor coordinates
197+
def project(self, rays):
198+
center, x, y, z = self.position(len(rays))
199+
200+
a = torch.einsum("N D , N D -> N", z, center)[:, None, None]
201+
b = torch.einsum("N D , N H W D -> N H W", z, rays).clip(self.eps)
202+
203+
c = torch.einsum("N H W , N H W D -> N H W D", a / b, rays)
204+
d = c - center[:, None, None, :]
205+
206+
proj = [
207+
torch.einsum("N H W D , N D -> N H W", d, x),
208+
torch.einsum("N H W D , N D -> N H W", d, y),
209+
]
210+
return torch.stack(proj, dim=3)
211+
212+
# Samples values in img at positions given by grid
213+
def sample_screen(self, img, grid):
214+
_, _, H_in, W_in = img.shape
215+
grid_x, grid_y = grid.unbind(dim=3)
216+
217+
grid_y = grid_y * W_in / H_in
218+
scale = W_in / (W_in - 1)
219+
220+
_, H_out, W_out, _ = grid.shape
221+
grid = [
222+
grid_x * scale * (W_out - 1) / W_out,
223+
grid_y * scale * (H_out - 1) / H_out,
224+
]
225+
226+
return nn.functional.grid_sample(
227+
input=img,
228+
grid=torch.stack(grid, dim=3),
229+
mode="bilinear",
230+
align_corners=False,
231+
)
232+
233+
234+
# Combines Retina and Monitor
235+
236+
237+
class SinglePerspective(nn.Module):
238+
def __init__(self, retina, monitor, pixel_transform, static_power=1.7):
239+
super().__init__()
240+
241+
self.retina = retina
242+
self.monitor = monitor
243+
self.pixel_transform = pixel_transform
244+
self.static_power = static_power
245+
246+
def forward(self, img, pupil_center):
247+
rays = self.retina.rays(pupil_center)
248+
grid = self.monitor.project(rays)
249+
250+
pixels = img
251+
252+
img = (pixels[:, None, 0, :, :] / 255.0).pow(self.static_power)
253+
behaviour = pixels[:, 1:, :, :]
254+
pixels = torch.concat([img, behaviour], axis=1)
255+
256+
pixels = self.monitor.sample_screen(pixels, grid)
257+
258+
img = self.pixel_transform(pixels[:, None, 0, :, :])
259+
behaviour = pixels[:, 1:, :, :]
260+
pixels = torch.concat([img, behaviour], axis=1)
261+
262+
return pixels
263+
264+
265+
class Perspective(nn.ModuleDict):
266+
def __init__(self, data_keys, retina_degree=75, mlp_features=16, mlp_layers=3):
267+
super().__init__()
268+
269+
for k in data_keys:
270+
retina = Retina(degree=retina_degree, mlp_features=mlp_features, mlp_layers=mlp_layers)
271+
monitor = Monitor()
272+
self.add_module(k, SinglePerspective(retina, monitor, PixelTransform()))

0 commit comments

Comments
 (0)