@@ -608,6 +608,7 @@ <h1>Source code for torchsom.visualization.base</h1><div class="highlight"><pre>
608608 < span class ="bp "> self</ span > < span class ="p "> ,</ span >
609609 < span class ="n "> quantization_errors</ span > < span class ="p "> :</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> float</ span > < span class ="p "> ],</ span >
610610 < span class ="n "> topographic_errors</ span > < span class ="p "> :</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> float</ span > < span class ="p "> ],</ span >
611+ < span class ="n "> bmus_data_map</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> [</ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> int</ span > < span class ="p "> ],</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ]],</ span >
611612 < span class ="n "> data</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span >
612613 < span class ="n "> target</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span >
613614 < span class ="n "> component_names</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ]]</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
@@ -625,6 +626,7 @@ <h1>Source code for torchsom.visualization.base</h1><div class="highlight"><pre>
625626< span class ="sd "> Args:</ span >
626627< span class ="sd "> quantization_errors (list[float]): List of quantization errors [epochs]</ span >
627628< span class ="sd "> topographic_errors (list[float]): List of topographic errors [epochs]</ span >
629+ < span class ="sd "> bmus_data_map (dict[tuple[int, int], list[int]]): Pre-computed BMU to data indices mapping</ span >
628630< span class ="sd "> data (torch.Tensor): Input data tensor [batch_size, n_features]</ span >
629631< span class ="sd "> target (torch.Tensor): Labels tensor for data points [batch_size]</ span >
630632< span class ="sd "> component_names (Optional[list[str]]): Names for each component/feature</ span >
@@ -649,15 +651,30 @@ <h1>Source code for torchsom.visualization.base</h1><div class="highlight"><pre>
649651 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_hit_map</ span > < span class ="p "> (</ span > < span class ="n "> data</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> )</ span >
650652 < span class ="k "> if</ span > < span class ="n "> metric_map</ span > < span class ="p "> :</ span >
651653 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_metric_map</ span > < span class ="p "> (</ span >
652- < span class ="n "> data</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> reduction_parameter</ span > < span class ="o "> =</ span > < span class ="s2 "> "mean"</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span >
654+ < span class ="n "> bmus_data_map</ span > < span class ="o "> =</ span > < span class ="n "> bmus_data_map</ span > < span class ="p "> ,</ span >
655+ < span class ="n "> data</ span > < span class ="o "> =</ span > < span class ="n "> data</ span > < span class ="p "> ,</ span >
656+ < span class ="n "> target</ span > < span class ="o "> =</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span >
657+ < span class ="n "> reduction_parameter</ span > < span class ="o "> =</ span > < span class ="s2 "> "mean"</ span > < span class ="p "> ,</ span >
658+ < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> ,</ span >
653659 < span class ="p "> )</ span >
654660 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_metric_map</ span > < span class ="p "> (</ span >
655- < span class ="n "> data</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> reduction_parameter</ span > < span class ="o "> =</ span > < span class ="s2 "> "std"</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span >
661+ < span class ="n "> bmus_data_map</ span > < span class ="o "> =</ span > < span class ="n "> bmus_data_map</ span > < span class ="p "> ,</ span >
662+ < span class ="n "> data</ span > < span class ="o "> =</ span > < span class ="n "> data</ span > < span class ="p "> ,</ span >
663+ < span class ="n "> target</ span > < span class ="o "> =</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span >
664+ < span class ="n "> reduction_parameter</ span > < span class ="o "> =</ span > < span class ="s2 "> "std"</ span > < span class ="p "> ,</ span >
665+ < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> ,</ span >
656666 < span class ="p "> )</ span >
657667 < span class ="k "> if</ span > < span class ="n "> score_map</ span > < span class ="p "> :</ span >
658- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_score_map</ span > < span class ="p "> (</ span > < span class ="n "> data</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> )</ span >
668+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_score_map</ span > < span class ="p "> (</ span >
669+ < span class ="n "> bmus_data_map</ span > < span class ="o "> =</ span > < span class ="n "> bmus_data_map</ span > < span class ="p "> ,</ span >
670+ < span class ="n "> target</ span > < span class ="o "> =</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span >
671+ < span class ="n "> total_samples</ span > < span class ="o "> =</ span > < span class ="n "> data</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ],</ span >
672+ < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> ,</ span >
673+ < span class ="p "> )</ span >
659674 < span class ="k "> if</ span > < span class ="n "> rank_map</ span > < span class ="p "> :</ span >
660- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_rank_map</ span > < span class ="p "> (</ span > < span class ="n "> data</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span > < span class ="p "> )</ span >
675+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_rank_map</ span > < span class ="p "> (</ span >
676+ < span class ="n "> bmus_data_map</ span > < span class ="o "> =</ span > < span class ="n "> bmus_data_map</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="o "> =</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span >
677+ < span class ="p "> )</ span >
661678 < span class ="k "> if</ span > < span class ="n "> component_planes</ span > < span class ="p "> :</ span >
662679 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _visualizer</ span > < span class ="o "> .</ span > < span class ="n "> plot_component_planes</ span > < span class ="p "> (</ span >
663680 < span class ="n "> component_names</ span > < span class ="o "> =</ span > < span class ="n "> component_names</ span > < span class ="p "> ,</ span > < span class ="n "> save_path</ span > < span class ="o "> =</ span > < span class ="n "> save_path</ span >
0 commit comments