Skip to content

Commit 3fe5c01

Browse files
committed
Refactor SOM and metrics modules:
- Removed unused imports in `som.py` for improved clarity and organization. - Commented out unnecessary device transfer in the `SOM` class to prevent potential issues. - Added back essential imports in `metrics.py` for axial distance calculations. - Enhanced inline comments for better understanding of logic in distance normalization.
1 parent 50d1752 commit 3fe5c01

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

torchsom/core/som.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111

1212
from ..utils.decay import DECAY_FUNCTIONS
1313
from ..utils.distances import DISTANCE_FUNCTIONS
14-
from ..utils.grid import (
15-
adjust_meshgrid_topology,
16-
axial_distance,
17-
convert_to_axial_coords,
18-
create_mesh_grid,
19-
)
14+
from ..utils.grid import adjust_meshgrid_topology, create_mesh_grid
2015
from ..utils.initialization import initialize_weights
2116
from ..utils.metrics import calculate_quantization_error, calculate_topographic_error
2217
from ..utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
@@ -338,7 +333,7 @@ def fit(
338333
Tuple[List[float], List[float]]: Quantization and topographic errors [epoch]
339334
"""
340335

341-
data = data.to(self.device)
336+
# data = data.to(self.device)
342337
dataset = TensorDataset(data)
343338
dataloader = DataLoader(
344339
dataset, batch_size=self.batch_size, shuffle=True, pin_memory=False
@@ -682,7 +677,7 @@ def build_distance_map(
682677
# Normalize the distance map
683678
max_distance = torch.max(
684679
distance_matrix.masked_fill(torch.isnan(distance_matrix), float("-inf"))
685-
)
680+
) # Replace NaNs with -inf to be ignored by max()
686681
return distance_matrix / max_distance if max_distance > 0 else distance_matrix
687682

688683
def build_bmus_data_map(

torchsom/utils/metrics.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import torch
55

6+
from ..utils.grid import axial_distance, convert_to_axial_coords
7+
68

79
def calculate_quantization_error(
810
data: torch.Tensor,
@@ -95,9 +97,6 @@ def calculate_topographic_error(
9597
bmu2_row = int(torch.div(indices[i, 1], y_dim, rounding_mode="floor"))
9698
bmu2_col = int(indices[i, 1] % y_dim)
9799

98-
# Convert to axial coordinates
99-
from ..utils.grid import axial_distance, convert_to_axial_coords
100-
101100
q1, r1 = convert_to_axial_coords(bmu1_row, bmu1_col)
102101
q2, r2 = convert_to_axial_coords(bmu2_row, bmu2_col)
103102

0 commit comments

Comments
 (0)