@@ -681,6 +681,7 @@ def collect_samples(
681681 historical_samples : torch .Tensor ,
682682 historical_outputs : torch .Tensor ,
683683 min_buffer_threshold : int = 50 ,
684+ bmus_idx_map : Dict [Tuple [int , int ], List [int ]] = None ,
684685 ) -> Tuple [torch .Tensor , torch .Tensor ]:
685686 """Collect historical samples similar to the query sample using SOM projection.
686687
@@ -699,11 +700,6 @@ def collect_samples(
699700 historical_samples = historical_samples .to (self .device )
700701 historical_outputs = historical_outputs .to (self .device )
701702
702- # Create a mapping of BMUs to historical sample indices if not already created
703- historical_bmus_idx_map = self .build_bmus_data_map (
704- historical_samples , return_indices = True
705- )
706-
707703 # Initialize collection lists and tracking set
708704 historical_data_list = []
709705 historical_output_list = []
@@ -714,11 +710,8 @@ def collect_samples(
714710 bmu_tuple = (int (bmu_pos [0 ].item ()), int (bmu_pos [1 ].item ()))
715711
716712 # Collect samples (features and outputs) from the query's BMU if any exist
717- if (
718- bmu_tuple in historical_bmus_idx_map
719- and len (historical_bmus_idx_map [bmu_tuple ]) > 0
720- ):
721- for sample_idx in historical_bmus_idx_map [bmu_tuple ]:
713+ if bmu_tuple in bmus_idx_map and len (bmus_idx_map [bmu_tuple ]) > 0 :
714+ for sample_idx in bmus_idx_map [bmu_tuple ]:
722715 historical_data_list .append (historical_samples [sample_idx ])
723716 historical_output_list .append (historical_outputs [sample_idx ])
724717
@@ -748,10 +741,10 @@ def collect_samples(
748741 if (
749742 0 <= neighbor_pos [0 ] < self .x
750743 and 0 <= neighbor_pos [1 ] < self .y
751- and neighbor_pos in historical_bmus_idx_map
752- and len (historical_bmus_idx_map [neighbor_pos ]) > 0
744+ and neighbor_pos in bmus_idx_map
745+ and len (bmus_idx_map [neighbor_pos ]) > 0
753746 ):
754- for sample_idx in historical_bmus_idx_map [neighbor_pos ]:
747+ for sample_idx in bmus_idx_map [neighbor_pos ]:
755748 historical_data_list .append (historical_samples [sample_idx ])
756749 historical_output_list .append (historical_outputs [sample_idx ])
757750
@@ -773,10 +766,7 @@ def collect_samples(
773766 neuron_pos = (row , col )
774767 if neuron_pos in visited_neurons :
775768 continue
776- if (
777- neuron_pos in historical_bmus_idx_map
778- and len (historical_bmus_idx_map [neuron_pos ]) > 0
779- ):
769+ if neuron_pos in bmus_idx_map and len (bmus_idx_map [neuron_pos ]) > 0 :
780770 distance = neurons_distance_map [row , col ].item ()
781771 heapq .heappush (distance_min_heap , (distance , neuron_pos ))
782772
@@ -787,10 +777,10 @@ def collect_samples(
787777 _ , closest_neuron = heapq .heappop (distance_min_heap )
788778 visited_neurons .add (closest_neuron )
789779 if (
790- closest_neuron in historical_bmus_idx_map
791- and len (historical_bmus_idx_map [closest_neuron ]) > 0
780+ closest_neuron in bmus_idx_map
781+ and len (bmus_idx_map [closest_neuron ]) > 0
792782 ):
793- for sample_idx in historical_bmus_idx_map [closest_neuron ]:
783+ for sample_idx in bmus_idx_map [closest_neuron ]:
794784 historical_data_list .append (historical_samples [sample_idx ])
795785 historical_output_list .append (historical_outputs [sample_idx ])
796786 historical_samples_count += 1
0 commit comments