@@ -415,7 +415,7 @@ <h1>Source code for torchsom.core.som</h1><div class="highlight"><pre>
415415< span > </ span > < span class ="sd "> """PyTorch implementation of classic Self Organizing Maps using batch learning."""</ span >
416416
417417< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> warnings</ span >
418- < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> typing</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> Any</ span > < span class ="p "> ,</ span > < span class ="n "> Callable</ span > < span class ="p "> ,</ span > < span class ="n "> Optional</ span >
418+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> typing</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> Any</ span > < span class ="p "> ,</ span > < span class ="n "> Callable</ span > < span class ="p "> ,</ span > < span class ="n "> Optional</ span > < span class =" p " > , </ span > < span class =" n " > Union </ span >
419419
420420< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> torch</ span >
421421< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> torch.nn</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> nn</ span >
@@ -848,7 +848,11 @@ <h1>Source code for torchsom.core.som</h1><div class="highlight"><pre>
848848 < span class ="n "> historical_outputs</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span >
849849 < span class ="n "> bmus_idx_map</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> [</ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> int</ span > < span class ="p "> ],</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ]],</ span >
850850 < span class ="n "> min_buffer_threshold</ span > < span class ="p "> :</ span > < span class ="nb "> int</ span > < span class ="o "> =</ span > < span class ="mi "> 50</ span > < span class ="p "> ,</ span >
851- < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ]:</ span >
851+ < span class ="n "> return_indices</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
852+ < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> Union</ span > < span class ="p "> [</ span >
853+ < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ],</ span >
854+ < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ],</ span >
855+ < span class ="p "> ]:</ span >
852856< span class ="w "> </ span > < span class ="sd "> """Collect historical samples similar to the query sample using SOM projection.</ span >
853857
854858< span class ="sd "> Args:</ span >
@@ -857,6 +861,11 @@ <h1>Source code for torchsom.core.som</h1><div class="highlight"><pre>
857861< span class ="sd "> historical_outputs (torch.Tensor): Historical outputs tensor [num_samples]</ span >
858862< span class ="sd "> bmus_idx_map (dict[tuple[int, int], list[int]]): BMU to data indices mapping</ span >
859863< span class ="sd "> min_buffer_threshold (int): Minimum buffer threshold</ span >
864+ < span class ="sd "> return_indices (bool): If True, also return the indices of collected samples</ span >
865+
866+ < span class ="sd "> Returns:</ span >
867+ < span class ="sd "> If return_indices is False: (historical_data_buffer, historical_output_buffer)</ span >
868+ < span class ="sd "> If return_indices is True: (historical_data_buffer, historical_output_buffer, indices_tensor)</ span >
860869< span class ="sd "> """</ span >
861870 < span class ="n "> query_sample</ span > < span class ="o "> =</ span > < span class ="n "> query_sample</ span > < span class ="o "> .</ span > < span class ="n "> to</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
862871 < span class ="n "> bmu_pos</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> identify_bmus</ span > < span class ="p "> (</ span > < span class ="n "> query_sample</ span > < span class ="p "> )</ span >
@@ -907,6 +916,8 @@ <h1>Source code for torchsom.core.som</h1><div class="highlight"><pre>
907916 < span class ="p "> )</ span >
908917 < span class ="n "> historical_data_buffer</ span > < span class ="o "> =</ span > < span class ="n "> historical_samples</ span > < span class ="p "> [</ span > < span class ="n "> indices_tensor</ span > < span class ="p "> ]</ span >
909918 < span class ="n "> historical_output_buffer</ span > < span class ="o "> =</ span > < span class ="n "> historical_outputs</ span > < span class ="p "> [</ span > < span class ="n "> indices_tensor</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> view</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
919+ < span class ="k "> if</ span > < span class ="n "> return_indices</ span > < span class ="p "> :</ span >
920+ < span class ="k "> return</ span > < span class ="n "> historical_data_buffer</ span > < span class ="p "> ,</ span > < span class ="n "> historical_output_buffer</ span > < span class ="p "> ,</ span > < span class ="n "> indices_tensor</ span >
910921 < span class ="k "> return</ span > < span class ="n "> historical_data_buffer</ span > < span class ="p "> ,</ span > < span class ="n "> historical_output_buffer</ span > </ div >
911922
912923
0 commit comments