Skip to content

Commit 761a7e4

Browse files
committed
handle command names
1 parent 27b9d82 commit 761a7e4

File tree

3 files changed

+45
-57
lines changed

3 files changed

+45
-57
lines changed

src/imitation/scripts/analyze.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,20 @@ def _get_exp_command(sd: sacred_util.SacredDicts) -> str:
152152
def _get_algo_name(sd: sacred_util.SacredDicts) -> str:
153153
exp_command = _get_exp_command(sd)
154154

155-
if exp_command == "gail":
156-
return "GAIL"
157-
elif exp_command == "airl":
158-
return "AIRL"
159-
elif exp_command == "train_bc":
160-
return "BC"
161-
elif exp_command == "train_dagger":
162-
return "DAgger"
155+
COMMAND_TO_ALGO = {
156+
"train_bc": "BC",
157+
"bc": "BC",
158+
"train_dagger": "DAgger",
159+
"dagger": "DAgger",
160+
"gail": "GAIL",
161+
"airl": "AIRL",
162+
"preference_comparisons": "Preference Comparisons",
163+
}
164+
165+
if exp_command.lower() in COMMAND_TO_ALGO.keys():
166+
return COMMAND_TO_ALGO[exp_command.lower()]
163167
else:
164-
return f"??exp_command={exp_command}"
168+
raise ValueError(f"Unknown command: {exp_command}")
165169

166170

167171
def _return_summaries(sd: sacred_util.SacredDicts) -> dict:

src/imitation/scripts/compare_to_baseline.py

Lines changed: 29 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,26 @@
1515
experiment returns, as reported by `imitation.scripts.analyze`.
1616
"""
1717

18-
import numpy as np
1918
import pandas as pd
2019
import scipy
2120

2221
from imitation.data import types
2322

2423

25-
def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
24+
def compare_results_to_baseline(results_filename: types.AnyPath) -> pd.DataFrame:
2625
"""Compare benchmark results to baseline results.
2726
2827
Args:
29-
results_file: Path to a CSV file containing experiment results.
28+
results_filename: Path to a CSV file containing experiment results.
3029
3130
Returns:
3231
A DataFrame containing a table of p-values comparing the experiment results to
3332
the baseline results.
3433
"""
35-
data = pd.read_csv(results_file)
36-
data["imit_return"] = data["imit_return_summary"].apply(
37-
lambda x: float(x.split(" ")[0]),
38-
)
39-
summary = (
40-
data[["algo", "env_name", "imit_return"]]
41-
.groupby(["algo", "env_name"])
42-
.describe()
43-
)
44-
summary.columns = summary.columns.get_level_values(1)
45-
summary = summary.reset_index()
46-
47-
# Table 2 (https://arxiv.org/pdf/2211.11972.pdf)
48-
# todo: store results in this repo outside this file
49-
baseline = pd.DataFrame.from_records(
50-
[
51-
{
52-
"algo": "??exp_command=bc",
53-
"env_name": "seals/Ant-v0",
54-
"mean": 1953,
55-
"margin": 123,
56-
},
57-
{
58-
"algo": "??exp_command=bc",
59-
"env_name": "seals/HalfCheetah-v0",
60-
"mean": 3446,
61-
"margin": 130,
62-
},
63-
],
64-
)
65-
baseline["count"] = 5
66-
baseline["confidence_level"] = 0.95
67-
# Back out the standard deviation from the margin of error.
34+
results_summary = load_and_summarize_csv(results_filename)
35+
baseline_summary = load_and_summarize_csv("baseline.csv")
6836

69-
t_score = scipy.stats.t.ppf(
70-
1 - ((1 - baseline["confidence_level"]) / 2),
71-
baseline["count"] - 1,
72-
)
73-
std_err = baseline["margin"] / t_score
74-
75-
baseline["std"] = std_err * np.sqrt(baseline["count"])
76-
77-
comparison = pd.merge(summary, baseline, on=["algo", "env_name"])
37+
comparison = pd.merge(results_summary, baseline_summary, on=["algo", "env_name"])
7838

7939
comparison["pvalue"] = scipy.stats.ttest_ind_from_stats(
8040
comparison["mean_x"],
@@ -88,6 +48,30 @@ def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
8848
return comparison[["algo", "env_name", "pvalue"]]
8949

9050

51+
def load_and_summarize_csv(results_filename: types.AnyPath) -> pd.DataFrame:
52+
"""Load a results CSV file and summarize the statistics.
53+
54+
Args:
55+
results_filename: Path to a CSV file containing experiment results.
56+
57+
Returns:
58+
A DataFrame containing the mean and standard deviation of the experiment
59+
returns, grouped by algorithm and environment.
60+
"""
61+
data = pd.read_csv(results_filename)
62+
data["imit_return"] = data["imit_return_summary"].apply(
63+
lambda x: float(x.split(" ")[0]),
64+
)
65+
summary = (
66+
data[["algo", "env_name", "imit_return"]]
67+
.groupby(["algo", "env_name"])
68+
.describe()
69+
)
70+
summary.columns = summary.columns.get_level_values(1)
71+
summary = summary.reset_index()
72+
return summary
73+
74+
9175
def main() -> None: # pragma: no cover
9276
"""Run the script."""
9377
import sys

tests/scripts/test_scripts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,8 +1110,8 @@ def test_compare_to_baseline_p_values(
11101110
comparison.to_csv(tmpfile)
11111111

11121112
assert (
1113-
compare_to_baseline.compare_results_to_baseline(results_file=tmpfile)["pvalue"][
1114-
0
1115-
]
1113+
compare_to_baseline.compare_results_to_baseline(results_filename=tmpfile)[
1114+
"pvalue"
1115+
][0]
11161116
< p_value
11171117
)

0 commit comments

Comments
 (0)