Skip to content

Commit 634374a

Browse files
2timesjayfacebook-github-bot
authored andcommitted
Replace slice.js and interact_slice.js with python and generic_plotly.js
Reviewed By: lena-kashtelyan Differential Revision: D16965821 fbshipit-source-id: 9e8e0f75e6fb789cddbbd1e15db2637e3fb4cf73
1 parent bf84a8f commit 634374a

File tree

5 files changed

+297
-173
lines changed

5 files changed

+297
-173
lines changed

ax/plot/helper.py

+147
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,150 @@ def relativize_data(
647647

648648
def rgb(arr: List[int]) -> str:
649649
return "rgb({},{},{})".format(*arr)
650+
651+
652+
def slice_config_to_trace(
653+
arm_data,
654+
arm_name_to_parameters,
655+
f,
656+
fit_data,
657+
grid,
658+
metric,
659+
param,
660+
rel,
661+
setx,
662+
sd,
663+
is_log,
664+
visible,
665+
):
666+
# format data
667+
res = relativize_data(f, sd, rel, arm_data, metric)
668+
f_final = res[0]
669+
sd_final = res[1]
670+
671+
# get data for standard deviation fill plot
672+
sd_upper = []
673+
sd_lower = []
674+
for i in range(len(sd)):
675+
sd_upper.append(f_final[i] + 2 * sd_final[i])
676+
sd_lower.append(f_final[i] - 2 * sd_final[i])
677+
grid_rev = list(reversed(grid))
678+
sd_lower_rev = list(reversed(sd_lower))
679+
sd_x = grid + grid_rev
680+
sd_y = sd_upper + sd_lower_rev
681+
682+
# get data for observed arms and error bars
683+
arm_x = []
684+
arm_y = []
685+
arm_sem = []
686+
for row in fit_data:
687+
parameters = arm_name_to_parameters[row["arm_name"]]
688+
plot = True
689+
for p in setx.keys():
690+
if p != param and parameters[p] != setx[p]:
691+
plot = False
692+
if plot:
693+
arm_x.append(parameters[param])
694+
arm_y.append(row["mean"])
695+
arm_sem.append(row["sem"])
696+
697+
arm_res = relativize_data(arm_y, arm_sem, rel, arm_data, metric)
698+
arm_y_final = arm_res[0]
699+
arm_sem_final = [x * 2 for x in arm_res[1]]
700+
701+
# create traces
702+
f_trace = {
703+
"x": grid,
704+
"y": f_final,
705+
"showlegend": False,
706+
"hoverinfo": "x+y",
707+
"line": {"color": "rgba(128, 177, 211, 1)"},
708+
"visible": visible,
709+
}
710+
711+
arms_trace = {
712+
"x": arm_x,
713+
"y": arm_y_final,
714+
"mode": "markers",
715+
"error_y": {
716+
"type": "data",
717+
"array": arm_sem_final,
718+
"visible": True,
719+
"color": "black",
720+
},
721+
"line": {"color": "black"},
722+
"showlegend": False,
723+
"hoverinfo": "x+y",
724+
"visible": visible,
725+
}
726+
727+
sd_trace = {
728+
"x": sd_x,
729+
"y": sd_y,
730+
"fill": "toself",
731+
"fillcolor": "rgba(128, 177, 211, 0.2)",
732+
"line": {"color": "transparent"},
733+
"showlegend": False,
734+
"hoverinfo": "none",
735+
"visible": visible,
736+
}
737+
738+
traces = [sd_trace, f_trace, arms_trace]
739+
740+
# iterate over out-of-sample arms
741+
for i, generator_run_name in enumerate(arm_data["out_of_sample"].keys()):
742+
ax = []
743+
ay = []
744+
asem = []
745+
atext = []
746+
747+
for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
748+
parameters = arm_data["out_of_sample"][generator_run_name][arm_name][
749+
"parameters"
750+
]
751+
plot = True
752+
for p in setx.keys():
753+
if p != param and parameters[p] != setx[p]:
754+
plot = False
755+
if plot:
756+
ax.append(parameters[param])
757+
ay.append(
758+
arm_data["out_of_sample"][generator_run_name][arm_name]["y_hat"][
759+
metric
760+
]
761+
)
762+
asem.append(
763+
arm_data["out_of_sample"][generator_run_name][arm_name]["se_hat"][
764+
metric
765+
]
766+
)
767+
atext.append("<em>Candidate " + arm_name + "</em>")
768+
769+
out_of_sample_arm_res = relativize_data(ay, asem, rel, arm_data, metric)
770+
ay_final = out_of_sample_arm_res[0]
771+
asem_final = [x * 2 for x in out_of_sample_arm_res[1]]
772+
773+
traces.append(
774+
{
775+
"hoverinfo": "text",
776+
"legendgroup": generator_run_name,
777+
"marker": {"color": "black", "symbol": i + 1, "opacity": 0.5},
778+
"mode": "markers",
779+
"error_y": {
780+
"type": "data",
781+
"array": asem_final,
782+
"visible": True,
783+
"color": "black",
784+
},
785+
"name": generator_run_name,
786+
"text": atext,
787+
"type": "scatter",
788+
"xaxis": "x",
789+
"x": ax,
790+
"yaxis": "y",
791+
"y": ay_final,
792+
"visible": visible,
793+
}
794+
)
795+
796+
return traces

ax/plot/js/interact_slice.js

-105
This file was deleted.

ax/plot/js/slice.js

-61
This file was deleted.

ax/plot/render.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,7 @@ class _AxPlotJSResources(enum.Enum):
3232

3333

3434
# JS-based plots that are supported in Ax should be registered here
35-
Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = {
36-
AxPlotTypes.GENERIC: "generic_plotly.js",
37-
AxPlotTypes.SLICE: "slice.js",
38-
AxPlotTypes.INTERACT_SLICE: "interact_slice.js",
39-
}
35+
Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = {AxPlotTypes.GENERIC: "generic_plotly.js"}
4036

4137

4238
def _load_js_resource(resource_type: _AxPlotJSResources) -> str:

0 commit comments

Comments
 (0)