Skip to content

Commit 79da14a

Browse files
committed
Refactored code completely. Proposed modifications to test.
1 parent 11de419 commit 79da14a

File tree

7 files changed

+366
-11
lines changed

7 files changed

+366
-11
lines changed

torchsom/core/som.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def _update_weights(
156156
for row, col in bmus
157157
]
158158
) # [batch_size, row_neurons, col_neurons]
159+
# ! Modification to test
160+
# # Vectorised: build a tensor of BMU coordinates and compute in one shot
161+
# coords = torch.stack([bmus[:, 0], bmus[:, 1]], dim=1).to(torch.long)
162+
# neighborhoods = self.neighborhood_fn(coords, sigma) # update neighborhood_fn to accept batched coords # [batch_size, row_neurons, col_neurons]
159163

160164
# Reshape for broadcasting
161165
neighborhoods = neighborhoods.view(batch_size, self.x, self.y, 1)
@@ -595,6 +599,21 @@ def build_distance_map(
595599
max_neighbors += len(neighbor_offsets)
596600
all_offsets.append(neighbor_offsets)
597601

602+
# # ! Modification to test
603+
# def _offsets_for_row(r: int) -> List[Tuple[int, int]]:
604+
# if self.topology == "hexagonal":
605+
# merged = []
606+
# for order in range(1, neighborhood_order + 1):
607+
# merged.extend(
608+
# get_hexagonal_offsets(order)["even" if r % 2 == 0 else "odd"]
609+
# )
610+
# return merged
611+
# else:
612+
# merged = []
613+
# for order in range(1, neighborhood_order + 1):
614+
# merged.extend(get_rectangular_offsets(order))
615+
# return merged
616+
598617
# Initialize distance map
599618
distance_matrix = torch.full(
600619
(self.weights.shape[0], self.weights.shape[1], max_neighbors),
@@ -924,6 +943,21 @@ def build_classification_map(
924943
for order in range(1, neighborhood_order + 1):
925944
neighborhood_offsets.extend(get_rectangular_offsets(order))
926945

946+
# ! Method to test
947+
# def _offsets_for_row(r: int) -> List[Tuple[int, int]]:
948+
# if self.topology == "hexagonal":
949+
# merged = []
950+
# for order in range(1, neighborhood_order + 1):
951+
# merged.extend(
952+
# get_hexagonal_offsets(order)["even" if r % 2 == 0 else "odd"]
953+
# )
954+
# return merged
955+
# else:
956+
# merged = []
957+
# for order in range(1, neighborhood_order + 1):
958+
# merged.extend(get_rectangular_offsets(order))
959+
# return merged
960+
927961
# Iterate through each activated neuron
928962
for bmu_pos, sample_indices in bmus_map.items():
929963
if len(sample_indices) > 0:

torchsom/utils/distances.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _euclidean_distance(
4545
torch.Tensor: euclidean distance between input and weights [row_neurons, col_neurons]
4646
"""
4747

48-
return torch.norm(torch.subtract(data, weights), dim=-1)
48+
return torch.max(torch.abs(data - weights), dim=-1).values
4949

5050

5151
def _manhattan_distance(
@@ -79,7 +79,8 @@ def _chebyshev_distance(
7979
torch.Tensor: chebyshev distance between input and weights [row_neurons, col_neurons]
8080
"""
8181

82-
return torch.max(torch.subtract(data, weights), dim=-1).values
82+
# return torch.max(torch.subtract(data, weights), dim=-1).values
83+
return torch.max(torch.abs(data - weights), dim=-1).values
8384

8485

8586
# TODO Check if this method works and if it is more efficient (also ensure it is compatible with batch framework)

torchsom/utils/grid.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ def convert_to_axial_coords(
8686
r = row # y axis
8787
return q, r
8888

89+
# ! Modification to test
90+
# # Match the even-r layout used in `adjust_meshgrid_topology`
91+
# # even rows -> shifted left by 0.5
92+
# # odd rows -> no horizontal shift
93+
# if row % 2 == 0: # even-r
94+
# q = col - 0.5 - (row // 2)
95+
# else: # odd-r
96+
# q = col - (row // 2)
97+
# r = row
98+
# return q, r
99+
89100

90101
def axial_distance(
91102
q1: float,

torchsom/utils/initialization.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def random_init(
3535

3636
return sampled_weights
3737

38+
# ! Modification to test
39+
# # Return value ignores the original weights tensor
40+
# weights.copy_(data[indices])
41+
# return weights
42+
3843
except RuntimeError as e:
3944
raise RuntimeError(f"Random initialization failed: {str(e)}")
4045

torchsom/utils/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def calculate_topographic_error(
7979
# Calculate distances between each data point and all neurons
8080
distances = distance_fn(data_expanded, weights_expanded)
8181

82+
# ! Modification to test: all the lines below could be vectorized
8283
# Get top 2 BMU indices for each sample
8384
batch_size = distances.shape[0]
8485
_, indices = torch.topk(distances.view(batch_size, -1), k=2, largest=False, dim=1)

torchsom/utils/neighborhood.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def _gaussian(
77
xx: torch.Tensor,
88
yy: torch.Tensor,
9-
c: Tuple[torch.Tensor, torch.Tensor],
9+
c: Tuple[torch.Tensor, torch.Tensor], # ! Tuple[int, int]
1010
sigma: float,
1111
) -> torch.Tensor:
1212
"""Gaussian neighborhood function to update weights.
@@ -55,7 +55,7 @@ def _gaussian(
5555
def _mexican_hat(
5656
xx: torch.Tensor,
5757
yy: torch.Tensor,
58-
c: Tuple[torch.Tensor, torch.Tensor],
58+
c: Tuple[torch.Tensor, torch.Tensor], # ! Tuple[int, int]
5959
sigma: float,
6060
) -> torch.Tensor:
6161
"""
@@ -91,7 +91,6 @@ def _mexican_hat(
9191
Returns:
9292
torch.Tensor: Mexican hat neighborhood weights. Element-wise product standing for the combined influence of mexican neighborhood around center c with a spread sigma [row_neurons, col_neurons].
9393
"""
94-
9594
denum = 2 * sigma * sigma
9695
cst = 1 / (torch.pi * torch.pow(torch.tensor(sigma), 4))
9796
squared_distances = torch.pow(xx - c[0], 2) + torch.pow(
@@ -100,6 +99,16 @@ def _mexican_hat(
10099
exp_distances = torch.exp(-squared_distances / denum)
101100
mexican_hat = cst * (1 - (1 / 2) * squared_distances / (2 * denum)) * exp_distances
102101

102+
# ! Modification to test
103+
# denum = 2 * sigma * sigma
104+
# sigma_t = torch.tensor(sigma, device=xx.device, dtype=xx.dtype)
105+
# cst = 1 / (torch.pi * sigma_t.pow(4))
106+
# squared_distances = torch.pow(xx - c[0], 2) + torch.pow(
107+
# yy - c[1], 2
108+
# ) # Squared distances from center [row_neurons, col_neurons]
109+
# exp_distances = torch.exp(-squared_distances / denum)
110+
# mexican_hat = cst * (1 - (1 / 2) * squared_distances / (2 * denum)) * exp_distances
111+
103112
# Ensure the central peak is exactly 1.0
104113
max_value = mexican_hat[c[0], c[1]]
105114
if max_value > 0:
@@ -110,7 +119,7 @@ def _mexican_hat(
110119
def _bubble(
111120
xx: torch.Tensor,
112121
yy: torch.Tensor,
113-
c: Tuple[torch.Tensor, torch.Tensor],
122+
c: Tuple[torch.Tensor, torch.Tensor], # ! Tuple[int, int]
114123
sigma: float,
115124
) -> torch.Tensor:
116125
"""
@@ -159,7 +168,7 @@ def _bubble(
159168
def _triangle(
160169
xx: torch.Tensor,
161170
yy: torch.Tensor,
162-
c: Tuple[torch.Tensor, torch.Tensor],
171+
c: Tuple[torch.Tensor, torch.Tensor], # ! Tuple[int, int]
163172
sigma: float,
164173
) -> torch.Tensor:
165174
"""

0 commit comments

Comments
 (0)