Skip to content

Commit 3f56f53

Browse files
author
nathanpaul.soeding
committed
adressed pr feedback: small docstrings for mlp.py, check for shifter and mlp usage
1 parent 48c3f87 commit 3f56f53

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

neuralpredictors/layers/encoders/firing_rate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ def forward(
6868
x = inputs
6969

7070
if self.perspective:
71+
if self.shifter:
72+
raise ValueError("both perspective and shifter cannot be present together, only one should be chosen")
73+
7174
if pupil_center is None:
7275
raise ValueError("pupil_center is not given")
76+
7377
x = self.perspective[data_key](x, pupil_center)
7478

7579
x = self.core(x)

neuralpredictors/layers/perspective/mlp.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
1+
'''
2+
code is adjusted from the Wang et al 2025 "Foundation model of neural activity predicts response to new stimulus types" implementation
3+
Specifically
4+
* https://github.com/cajal/fnn/blob/main/fnn/model/pixels.py
5+
* https://github.com/cajal/fnn/blob/main/fnn/model/perspectives.py
6+
'''
17
import torch
28
from torch import nn
39

4-
510
def angles_to_rmat3d(angles):
11+
"""
12+
Convert batches of Euler angles (x, y, z) to 3D rotation matrices.
13+
14+
Args:
15+
angles (torch.Tensor): Tensor of shape (N, 3), angles in radians.
16+
17+
Returns:
18+
torch.Tensor: Tensor of shape (N, 3, 3), rotation matrices (Rz * Ry * Rx).
19+
"""
620
x, y, z = torch.unbind(angles, axis=-1)
721
N = len(x)
822

@@ -38,6 +52,9 @@ def angles_to_rmat3d(angles):
3852

3953

4054
class PixelTransform(nn.Module):
55+
"""
56+
Nonlinear pixel intensity transform with learnable power, scale, and offset.
57+
"""
4158
def __init__(self, max_power=1, init_scale=1, init_offset=0, eps=1e-5):
4259
super().__init__()
4360

@@ -54,21 +71,23 @@ def power(self):
5471

5572
def forward(self, pixels):
5673
return pixels.add(self.eps).pow(self.power).mul(self.scale).add(self.offset)
57-
58-
59-
class Scale(nn.Module):
60-
def __init__(self, gamma):
61-
super().__init__()
62-
self.gamma = gamma
63-
64-
def forward(self, x):
65-
return x * self.gamma
66-
74+
6775

6876
class Retina(nn.Module):
77+
"""
78+
Models a retina that maps pupil centers to 3D rays via an MLP.
79+
80+
Args:
81+
degree (float): Field of view in degrees.
82+
height, width (int): Retina grid resolution.
83+
dim_in, dim_out (int): Input/output dimensions for MLP.
84+
mlp_features (int): Hidden feature size.
85+
mlp_layers (int): Number of MLP layers.
86+
max_angle (float): Maximum rotation angle in degrees.
87+
"""
6988
def __init__(
7089
self,
71-
degree=75,
90+
degree=50,
7291
height=36,
7392
width=64,
7493
dim_in=2,
@@ -148,6 +167,11 @@ def rays(self, pupil_center):
148167

149168

150169
class Monitor(nn.Module):
170+
"""
171+
Models a monitor in 3D space with optimizable position and orientation.
172+
173+
Provides projection of retinal rays onto the monitor plane and sampling of images.
174+
"""
151175
def __init__(
152176
self,
153177
init_center_x=0,
@@ -235,6 +259,9 @@ def sample_screen(self, img, grid):
235259

236260

237261
class SinglePerspective(nn.Module):
262+
"""
263+
Combines Retina, Monitor, and PixelTransform to generate a single visual perspective.
264+
"""
238265
def __init__(self, retina, monitor, pixel_transform, static_power=1.7):
239266
super().__init__()
240267

@@ -263,6 +290,9 @@ def forward(self, img, pupil_center):
263290

264291

265292
class Perspective(nn.ModuleDict):
293+
"""
294+
Container for multiple SinglePerspective modules keyed by dataset identifiers.
295+
"""
266296
def __init__(self, data_keys, retina_degree=75, mlp_features=16, mlp_layers=3):
267297
super().__init__()
268298

0 commit comments

Comments
 (0)