1+ '''
2+ Code adapted from the Wang et al. (2025) "Foundation model of neural activity predicts response to new stimulus types" implementation.
3+ Specifically:
4+ * https://github.com/cajal/fnn/blob/main/fnn/model/pixels.py
5+ * https://github.com/cajal/fnn/blob/main/fnn/model/perspectives.py
6+ '''
17import torch
28from torch import nn
39
4-
510def angles_to_rmat3d (angles ):
11+ """
12+ Convert batches of Euler angles (x, y, z) to 3D rotation matrices.
13+
14+ Args:
15+ angles (torch.Tensor): Tensor of shape (N, 3), angles in radians.
16+
17+ Returns:
18+ torch.Tensor: Tensor of shape (N, 3, 3), rotation matrices (Rz * Ry * Rx).
19+ """
620 x , y , z = torch .unbind (angles , axis = - 1 )
721 N = len (x )
822
@@ -38,6 +52,9 @@ def angles_to_rmat3d(angles):
3852
3953
4054class PixelTransform (nn .Module ):
55+ """
56+ Nonlinear pixel intensity transform with learnable power, scale, and offset: (pixels + eps) ** power * scale + offset.
57+ """
4158 def __init__ (self , max_power = 1 , init_scale = 1 , init_offset = 0 , eps = 1e-5 ):
4259 super ().__init__ ()
4360
@@ -54,21 +71,23 @@ def power(self):
5471
5572 def forward (self , pixels ):
5673 return pixels .add (self .eps ).pow (self .power ).mul (self .scale ).add (self .offset )
57-
58-
59- class Scale (nn .Module ):
60- def __init__ (self , gamma ):
61- super ().__init__ ()
62- self .gamma = gamma
63-
64- def forward (self , x ):
65- return x * self .gamma
66-
74+
6775
6876class Retina (nn .Module ):
77+ """
78+ Models a retina that maps pupil centers to 3D rays via an MLP.
79+
80+ Args:
81+ degree (float): Field of view in degrees.
82+ height, width (int): Retina grid resolution.
83+ dim_in, dim_out (int): Input/output dimensions for MLP.
84+ mlp_features (int): Hidden feature size.
85+ mlp_layers (int): Number of MLP layers.
86+ max_angle (float): Maximum rotation angle in degrees.
87+ """
6988 def __init__ (
7089 self ,
71- degree = 75 ,
90+ degree = 50 ,
7291 height = 36 ,
7392 width = 64 ,
7493 dim_in = 2 ,
@@ -148,6 +167,11 @@ def rays(self, pupil_center):
148167
149168
150169class Monitor (nn .Module ):
170+ """
171+ Models a monitor in 3D space with optimizable position and orientation.
172+
173+ Provides projection of retinal rays onto the monitor plane and sampling of images.
174+ """
151175 def __init__ (
152176 self ,
153177 init_center_x = 0 ,
@@ -235,6 +259,9 @@ def sample_screen(self, img, grid):
235259
236260
237261class SinglePerspective (nn .Module ):
262+ """
263+ Combines Retina, Monitor, and PixelTransform to generate a single visual perspective.
264+ """
238265 def __init__ (self , retina , monitor , pixel_transform , static_power = 1.7 ):
239266 super ().__init__ ()
240267
@@ -263,6 +290,9 @@ def forward(self, img, pupil_center):
263290
264291
265292class Perspective (nn .ModuleDict ):
293+ """
294+ Container for multiple SinglePerspective modules keyed by dataset identifiers.
295+ """
266296 def __init__ (self , data_keys , retina_degree = 75 , mlp_features = 16 , mlp_layers = 3 ):
267297 super ().__init__ ()
268298
0 commit comments