@@ -647,3 +647,150 @@ def relativize_data(
647
647
648
648
def rgb (arr : List [int ]) -> str :
649
649
return "rgb({},{},{})" .format (* arr )
650
+
651
+
652
+ def slice_config_to_trace (
653
+ arm_data ,
654
+ arm_name_to_parameters ,
655
+ f ,
656
+ fit_data ,
657
+ grid ,
658
+ metric ,
659
+ param ,
660
+ rel ,
661
+ setx ,
662
+ sd ,
663
+ is_log ,
664
+ visible ,
665
+ ):
666
+ # format data
667
+ res = relativize_data (f , sd , rel , arm_data , metric )
668
+ f_final = res [0 ]
669
+ sd_final = res [1 ]
670
+
671
+ # get data for standard deviation fill plot
672
+ sd_upper = []
673
+ sd_lower = []
674
+ for i in range (len (sd )):
675
+ sd_upper .append (f_final [i ] + 2 * sd_final [i ])
676
+ sd_lower .append (f_final [i ] - 2 * sd_final [i ])
677
+ grid_rev = list (reversed (grid ))
678
+ sd_lower_rev = list (reversed (sd_lower ))
679
+ sd_x = grid + grid_rev
680
+ sd_y = sd_upper + sd_lower_rev
681
+
682
+ # get data for observed arms and error bars
683
+ arm_x = []
684
+ arm_y = []
685
+ arm_sem = []
686
+ for row in fit_data :
687
+ parameters = arm_name_to_parameters [row ["arm_name" ]]
688
+ plot = True
689
+ for p in setx .keys ():
690
+ if p != param and parameters [p ] != setx [p ]:
691
+ plot = False
692
+ if plot :
693
+ arm_x .append (parameters [param ])
694
+ arm_y .append (row ["mean" ])
695
+ arm_sem .append (row ["sem" ])
696
+
697
+ arm_res = relativize_data (arm_y , arm_sem , rel , arm_data , metric )
698
+ arm_y_final = arm_res [0 ]
699
+ arm_sem_final = [x * 2 for x in arm_res [1 ]]
700
+
701
+ # create traces
702
+ f_trace = {
703
+ "x" : grid ,
704
+ "y" : f_final ,
705
+ "showlegend" : False ,
706
+ "hoverinfo" : "x+y" ,
707
+ "line" : {"color" : "rgba(128, 177, 211, 1)" },
708
+ "visible" : visible ,
709
+ }
710
+
711
+ arms_trace = {
712
+ "x" : arm_x ,
713
+ "y" : arm_y_final ,
714
+ "mode" : "markers" ,
715
+ "error_y" : {
716
+ "type" : "data" ,
717
+ "array" : arm_sem_final ,
718
+ "visible" : True ,
719
+ "color" : "black" ,
720
+ },
721
+ "line" : {"color" : "black" },
722
+ "showlegend" : False ,
723
+ "hoverinfo" : "x+y" ,
724
+ "visible" : visible ,
725
+ }
726
+
727
+ sd_trace = {
728
+ "x" : sd_x ,
729
+ "y" : sd_y ,
730
+ "fill" : "toself" ,
731
+ "fillcolor" : "rgba(128, 177, 211, 0.2)" ,
732
+ "line" : {"color" : "transparent" },
733
+ "showlegend" : False ,
734
+ "hoverinfo" : "none" ,
735
+ "visible" : visible ,
736
+ }
737
+
738
+ traces = [sd_trace , f_trace , arms_trace ]
739
+
740
+ # iterate over out-of-sample arms
741
+ for i , generator_run_name in enumerate (arm_data ["out_of_sample" ].keys ()):
742
+ ax = []
743
+ ay = []
744
+ asem = []
745
+ atext = []
746
+
747
+ for arm_name in arm_data ["out_of_sample" ][generator_run_name ].keys ():
748
+ parameters = arm_data ["out_of_sample" ][generator_run_name ][arm_name ][
749
+ "parameters"
750
+ ]
751
+ plot = True
752
+ for p in setx .keys ():
753
+ if p != param and parameters [p ] != setx [p ]:
754
+ plot = False
755
+ if plot :
756
+ ax .append (parameters [param ])
757
+ ay .append (
758
+ arm_data ["out_of_sample" ][generator_run_name ][arm_name ]["y_hat" ][
759
+ metric
760
+ ]
761
+ )
762
+ asem .append (
763
+ arm_data ["out_of_sample" ][generator_run_name ][arm_name ]["se_hat" ][
764
+ metric
765
+ ]
766
+ )
767
+ atext .append ("<em>Candidate " + arm_name + "</em>" )
768
+
769
+ out_of_sample_arm_res = relativize_data (ay , asem , rel , arm_data , metric )
770
+ ay_final = out_of_sample_arm_res [0 ]
771
+ asem_final = [x * 2 for x in out_of_sample_arm_res [1 ]]
772
+
773
+ traces .append (
774
+ {
775
+ "hoverinfo" : "text" ,
776
+ "legendgroup" : generator_run_name ,
777
+ "marker" : {"color" : "black" , "symbol" : i + 1 , "opacity" : 0.5 },
778
+ "mode" : "markers" ,
779
+ "error_y" : {
780
+ "type" : "data" ,
781
+ "array" : asem_final ,
782
+ "visible" : True ,
783
+ "color" : "black" ,
784
+ },
785
+ "name" : generator_run_name ,
786
+ "text" : atext ,
787
+ "type" : "scatter" ,
788
+ "xaxis" : "x" ,
789
+ "x" : ax ,
790
+ "yaxis" : "y" ,
791
+ "y" : ay_final ,
792
+ "visible" : visible ,
793
+ }
794
+ )
795
+
796
+ return traces
0 commit comments