 experiment returns, as reported by `imitation.scripts.analyze`.
 """

-import numpy as np
 import pandas as pd
 import scipy

 from imitation.data import types


-def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
+def compare_results_to_baseline(results_filename: types.AnyPath) -> pd.DataFrame:
     """Compare benchmark results to baseline results.

     Args:
-        results_file: Path to a CSV file containing experiment results.
+        results_filename: Path to a CSV file containing experiment results.

     Returns:
         A DataFrame of p-values comparing the experiment results to the
         baseline results.
     """
-    data = pd.read_csv(results_file)
-    data["imit_return"] = data["imit_return_summary"].apply(
-        lambda x: float(x.split(" ")[0]),
-    )
-    summary = (
-        data[["algo", "env_name", "imit_return"]]
-        .groupby(["algo", "env_name"])
-        .describe()
-    )
-    summary.columns = summary.columns.get_level_values(1)
-    summary = summary.reset_index()
-
-    # Table 2 (https://arxiv.org/pdf/2211.11972.pdf)
-    # todo: store results in this repo outside this file
-    baseline = pd.DataFrame.from_records(
-        [
-            {
-                "algo": "??exp_command=bc",
-                "env_name": "seals/Ant-v0",
-                "mean": 1953,
-                "margin": 123,
-            },
-            {
-                "algo": "??exp_command=bc",
-                "env_name": "seals/HalfCheetah-v0",
-                "mean": 3446,
-                "margin": 130,
-            },
-        ],
-    )
-    baseline["count"] = 5
-    baseline["confidence_level"] = 0.95
-    # Back out the standard deviation from the margin of error.
+    results_summary = load_and_summarize_csv(results_filename)
+    baseline_summary = load_and_summarize_csv("baseline.csv")

-    t_score = scipy.stats.t.ppf(
-        1 - ((1 - baseline["confidence_level"]) / 2),
-        baseline["count"] - 1,
-    )
-    std_err = baseline["margin"] / t_score
-
-    baseline["std"] = std_err * np.sqrt(baseline["count"])
-
-    comparison = pd.merge(summary, baseline, on=["algo", "env_name"])
+    comparison = pd.merge(results_summary, baseline_summary, on=["algo", "env_name"])

     comparison["pvalue"] = scipy.stats.ttest_ind_from_stats(
         comparison["mean_x"],
@@ -88,6 +48,30 @@ def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
     return comparison[["algo", "env_name", "pvalue"]]


+def load_and_summarize_csv(results_filename: types.AnyPath) -> pd.DataFrame:
+    """Load a results CSV file and summarize the statistics.
+
+    Args:
+        results_filename: Path to a CSV file containing experiment results.
+
+    Returns:
+        A DataFrame containing the mean and standard deviation of the experiment
+        returns, grouped by algorithm and environment.
+    """
+    data = pd.read_csv(results_filename)
+    data["imit_return"] = data["imit_return_summary"].apply(
+        lambda x: float(x.split(" ")[0]),
+    )
+    summary = (
+        data[["algo", "env_name", "imit_return"]]
+        .groupby(["algo", "env_name"])
+        .describe()
+    )
+    summary.columns = summary.columns.get_level_values(1)
+    summary = summary.reset_index()
+    return summary
+
+
 def main() -> None:  # pragma: no cover
     """Run the script."""
     import sys
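
For anyone trying this change locally, the sketch below is one way to exercise the refactored helpers; it is not part of this commit. It assumes the module shown in the diff is importable as `util` (the real import path is not visible in this excerpt) and that result CSVs follow the layout produced by `imitation.scripts.analyze`: columns `algo`, `env_name`, and `imit_return_summary`, where the first whitespace-separated token of `imit_return_summary` is the mean return. Note that `compare_results_to_baseline` now reads a raw results file named `baseline.csv` from the working directory instead of the hard-coded table removed above, so that file must exist before the comparison is run.

# Hypothetical smoke test for the refactored helpers (not part of this commit).
from typing import List

import pandas as pd

import util  # assumed import path for the module shown in this diff


def write_fake_results(path: str, returns: List[float]) -> None:
    """Write a minimal results CSV in the assumed `imitation.scripts.analyze` schema."""
    pd.DataFrame(
        {
            "algo": ["bc"] * len(returns),
            "env_name": ["seals/Ant-v0"] * len(returns),
            # The first token is parsed as the mean return by load_and_summarize_csv.
            "imit_return_summary": [f"{r} ± 0" for r in returns],
        },
    ).to_csv(path, index=False)


# compare_results_to_baseline reads "baseline.csv" from the working directory,
# so write both a fake baseline and a fake results file before calling it.
write_fake_results("baseline.csv", [1900.0, 1950.0, 2000.0, 1925.0, 1975.0])
write_fake_results("new_results.csv", [2100.0, 2150.0, 2050.0, 2125.0, 2075.0])

summary = util.load_and_summarize_csv("new_results.csv")
print(summary[["algo", "env_name", "count", "mean", "std"]])

print(util.compare_results_to_baseline("new_results.csv"))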