66def _gaussian (
77 xx : torch .Tensor ,
88 yy : torch .Tensor ,
9- c : Tuple [torch .Tensor , torch .Tensor ],
9+ c : Tuple [torch .Tensor , torch .Tensor ], # ! Tuple[int, int]
1010 sigma : float ,
1111) -> torch .Tensor :
1212 """Gaussian neighborhood function to update weights.
@@ -55,7 +55,7 @@ def _gaussian(
5555def _mexican_hat (
5656 xx : torch .Tensor ,
5757 yy : torch .Tensor ,
58- c : Tuple [torch .Tensor , torch .Tensor ],
58+ c : Tuple [torch .Tensor , torch .Tensor ], # ! Tuple[int, int]
5959 sigma : float ,
6060) -> torch .Tensor :
6161 """
@@ -91,7 +91,6 @@ def _mexican_hat(
9191 Returns:
9292 torch.Tensor: Mexican hat neighborhood weights. Element-wise product standing for the combined influence of mexican neighborhood around center c with a spread sigma [row_neurons, col_neurons].
9393 """
94-
9594 denum = 2 * sigma * sigma
9695 cst = 1 / (torch .pi * torch .pow (torch .tensor (sigma ), 4 ))
9796 squared_distances = torch .pow (xx - c [0 ], 2 ) + torch .pow (
@@ -100,6 +99,16 @@ def _mexican_hat(
10099 exp_distances = torch .exp (- squared_distances / denum )
101100 mexican_hat = cst * (1 - (1 / 2 ) * squared_distances / (2 * denum )) * exp_distances
102101
102+ # ! Modification to test
103+ # denum = 2 * sigma * sigma
104+ # sigma_t = torch.tensor(sigma, device=xx.device, dtype=xx.dtype)
105+ # cst = 1 / (torch.pi * sigma_t.pow(4))
106+ # squared_distances = torch.pow(xx - c[0], 2) + torch.pow(
107+ # yy - c[1], 2
108+ # ) # Squared distances from center [row_neurons, col_neurons]
109+ # exp_distances = torch.exp(-squared_distances / denum)
110+ # mexican_hat = cst * (1 - (1 / 2) * squared_distances / (2 * denum)) * exp_distances
111+
103112 # Ensure the central peak is exactly 1.0
104113 max_value = mexican_hat [c [0 ], c [1 ]]
105114 if max_value > 0 :
@@ -110,7 +119,7 @@ def _mexican_hat(
110119def _bubble (
111120 xx : torch .Tensor ,
112121 yy : torch .Tensor ,
113- c : Tuple [torch .Tensor , torch .Tensor ],
122+ c : Tuple [torch .Tensor , torch .Tensor ], # ! Tuple[int, int]
114123 sigma : float ,
115124) -> torch .Tensor :
116125 """
@@ -159,7 +168,7 @@ def _bubble(
159168def _triangle (
160169 xx : torch .Tensor ,
161170 yy : torch .Tensor ,
162- c : Tuple [torch .Tensor , torch .Tensor ],
171+ c : Tuple [torch .Tensor , torch .Tensor ], # ! Tuple[int, int]
163172 sigma : float ,
164173) -> torch .Tensor :
165174 """
0 commit comments