Skip to content

Commit 0d2da5a

Browse files
committed
Updated collect_samples.
1 parent c58cdb3 commit 0d2da5a

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed
-18 Bytes
Binary file not shown.

torchsom/core.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ def collect_samples(
681681
historical_samples: torch.Tensor,
682682
historical_outputs: torch.Tensor,
683683
min_buffer_threshold: int = 50,
684+
bmus_idx_map: Dict[Tuple[int, int], List[int]] = None,
684685
) -> Tuple[torch.Tensor, torch.Tensor]:
685686
"""Collect historical samples similar to the query sample using SOM projection.
686687
@@ -699,11 +700,6 @@ def collect_samples(
699700
historical_samples = historical_samples.to(self.device)
700701
historical_outputs = historical_outputs.to(self.device)
701702

702-
# Create a mapping of BMUs to historical sample indices if not already created
703-
historical_bmus_idx_map = self.build_bmus_data_map(
704-
historical_samples, return_indices=True
705-
)
706-
707703
# Initialize collection lists and tracking set
708704
historical_data_list = []
709705
historical_output_list = []
@@ -714,11 +710,8 @@ def collect_samples(
714710
bmu_tuple = (int(bmu_pos[0].item()), int(bmu_pos[1].item()))
715711

716712
# Collect samples (features and outputs) from the query's BMU if any exist
717-
if (
718-
bmu_tuple in historical_bmus_idx_map
719-
and len(historical_bmus_idx_map[bmu_tuple]) > 0
720-
):
721-
for sample_idx in historical_bmus_idx_map[bmu_tuple]:
713+
if bmu_tuple in bmus_idx_map and len(bmus_idx_map[bmu_tuple]) > 0:
714+
for sample_idx in bmus_idx_map[bmu_tuple]:
722715
historical_data_list.append(historical_samples[sample_idx])
723716
historical_output_list.append(historical_outputs[sample_idx])
724717

@@ -748,10 +741,10 @@ def collect_samples(
748741
if (
749742
0 <= neighbor_pos[0] < self.x
750743
and 0 <= neighbor_pos[1] < self.y
751-
and neighbor_pos in historical_bmus_idx_map
752-
and len(historical_bmus_idx_map[neighbor_pos]) > 0
744+
and neighbor_pos in bmus_idx_map
745+
and len(bmus_idx_map[neighbor_pos]) > 0
753746
):
754-
for sample_idx in historical_bmus_idx_map[neighbor_pos]:
747+
for sample_idx in bmus_idx_map[neighbor_pos]:
755748
historical_data_list.append(historical_samples[sample_idx])
756749
historical_output_list.append(historical_outputs[sample_idx])
757750

@@ -773,10 +766,7 @@ def collect_samples(
773766
neuron_pos = (row, col)
774767
if neuron_pos in visited_neurons:
775768
continue
776-
if (
777-
neuron_pos in historical_bmus_idx_map
778-
and len(historical_bmus_idx_map[neuron_pos]) > 0
779-
):
769+
if neuron_pos in bmus_idx_map and len(bmus_idx_map[neuron_pos]) > 0:
780770
distance = neurons_distance_map[row, col].item()
781771
heapq.heappush(distance_min_heap, (distance, neuron_pos))
782772

@@ -787,10 +777,10 @@ def collect_samples(
787777
_, closest_neuron = heapq.heappop(distance_min_heap)
788778
visited_neurons.add(closest_neuron)
789779
if (
790-
closest_neuron in historical_bmus_idx_map
791-
and len(historical_bmus_idx_map[closest_neuron]) > 0
780+
closest_neuron in bmus_idx_map
781+
and len(bmus_idx_map[closest_neuron]) > 0
792782
):
793-
for sample_idx in historical_bmus_idx_map[closest_neuron]:
783+
for sample_idx in bmus_idx_map[closest_neuron]:
794784
historical_data_list.append(historical_samples[sample_idx])
795785
historical_output_list.append(historical_outputs[sample_idx])
796786
historical_samples_count += 1

0 commit comments

Comments
 (0)