Skip to content

Commit d21a76c

Browse files
committed
Fix check_diff and format
1 parent 99e603b commit d21a76c

File tree

2 files changed

+121
-100
lines changed

2 files changed

+121
-100
lines changed

src/examples/colors/exp1/check_diff.py

Lines changed: 119 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from scipy.stats import ttest_1samp
23
import pandas as pd
34
import numpy as np
@@ -20,66 +21,58 @@ def get_closest(points: np.ndarray, point: np.ndarray, amount: int) -> np.ndarra
2021

2122
def convert_model(
2223
model: pd.DataFrame,
23-
) -> tuple[
24-
np.ndarray,
25-
np.ndarray,
26-
np.ndarray,
27-
np.ndarray,
28-
np.ndarray,
29-
np.ndarray,
30-
np.ndarray,
31-
np.ndarray,
32-
]:
24+
) -> dict[str, dict[str, Any]]:
3325
"""
34-
Converts an input model taken from the color model `.csv` file and returns helpful arrays which are used in calculations
26+
Converts an input model taken from the color model `.csv` file and returns a helpful dict to use in calculations
3527
3628
Args:
3729
model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file
3830
3931
Returns:
40-
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
41-
The following arrays, in order:
42-
- The points from optimal encoders
43-
- The points from suboptimal encoders
44-
- The points from natural language encoders
45-
- The concatenation of all points (in order: optimal, suboptimal, natural)
46-
- The concatenation of suboptimal and natural language encoder points
47-
- The concatination of optimal and suboptimal encoder points
48-
- The quasi-convexity of the q(m|w) distributions for all encoders in points
49-
- The quasi-convexity of the q(u|w) distributions for all encoders in points
32+
dict[str, dict[str, Any]]: Contains the following data
33+
Each of "optimal", "suboptimal", and "natural" maps to a dictionary with the following keys:
34+
- "points": All the points of encoders of that type
35+
- "check": All the points to check against in the neighbor comparison
36+
- "check_convexities": A dict with the quasi-convexity of the q(m|w) and q(u|w) distributions for the encoders in "check"
37+
- "point_convexities": A dict with the quasi-convexity of the q(m|w) and q(u|w) distributions for the encoders in "points"
5038
"""
51-
frontier = []
52-
suboptimal = []
53-
natural = []
54-
convexities_qmw = []
55-
convexities_quw = []
39+
output = {
40+
"optimal": {
41+
"points": [],
42+
"check": [],
43+
"check_convexities": {"qmw": [], "quw": []},
44+
"point_convexities": {"qmw": [], "quw": []},
45+
},
46+
"suboptimal": {
47+
"points": [],
48+
"check": [],
49+
"check_convexities": {"qmw": [], "quw": []},
50+
"point_convexities": {"qmw": [], "quw": []},
51+
},
52+
"natural": {
53+
"points": [],
54+
"check": [],
55+
"check_convexities": {"qmw": [], "quw": []},
56+
"point_convexities": {"qmw": [], "quw": []},
57+
},
58+
}
5659

5760
for _, row in model.iterrows():
58-
if row["type"] == "optimal":
59-
frontier.append([row["complexity"], row["accuracy"]])
60-
if row["type"] == "suboptimal":
61-
suboptimal.append([row["complexity"], row["accuracy"]])
62-
if row["type"] == "natural":
63-
natural.append([row["complexity"], row["accuracy"]])
64-
convexities_qmw.append(row["convexity-qmw"])
65-
convexities_quw.append(row["convexity-quw"])
66-
67-
frontier = np.array(frontier)
68-
suboptimal = np.array(suboptimal)
69-
natural = np.array(natural)
70-
points = np.concat((frontier, suboptimal, natural))
71-
check_frontier = np.concat((suboptimal, natural))
72-
check_natural = np.concat((frontier, suboptimal))
73-
return (
74-
frontier,
75-
suboptimal,
76-
natural,
77-
points,
78-
check_frontier,
79-
check_natural,
80-
convexities_qmw,
81-
convexities_quw,
82-
)
61+
for k in output.keys():
62+
if k == row["type"]:
63+
output[k]["points"].append([row["complexity"], row["accuracy"]])
64+
output[k]["point_convexities"]["quw"].append(row["convexity-quw"])
65+
output[k]["point_convexities"]["qmw"].append(row["convexity-qmw"])
66+
if k != row["type"] or row["type"] == "suboptimal":
67+
output[k]["check"].append([row["complexity"], row["accuracy"]])
68+
output[k]["check_convexities"]["quw"].append(row["convexity-quw"])
69+
output[k]["check_convexities"]["qmw"].append(row["convexity-qmw"])
70+
71+
for k in output.keys():
72+
output[k]["points"] = np.array(output[k]["points"])
73+
output[k]["check"] = np.array(output[k]["check"])
74+
75+
return output
8376

8477

8578
def get_neighbor_comparison(amount: int, model: pd.DataFrame):
@@ -90,51 +83,77 @@ def get_neighbor_comparison(amount: int, model: pd.DataFrame):
9083
amount (int): The number of neighbors to compare to.
9184
model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file.
9285
"""
93-
_, _, _, points, check_frontier, check_natural, convexities_qmw, convexities_quw = (
94-
convert_model(model)
86+
converted_model = convert_model(model)
87+
88+
comparison_qmw = {"optimal": [], "suboptimal": [], "natural": []}
89+
comparison_quw = {"optimal": [], "suboptimal": [], "natural": []}
90+
for k in converted_model.keys():
91+
suboptimal = k == "suboptimal"
92+
for i, p in enumerate(converted_model[k]["points"]):
93+
higher_than_qmw = -0.5 if suboptimal else 0
94+
higher_than_quw = -0.5 if suboptimal else 0
95+
if suboptimal:
96+
# Check includes itself so get 11 instead of 10
97+
closest = get_closest(converted_model[k]["check"], p, amount + 1)
98+
else:
99+
closest = get_closest(converted_model[k]["check"], p, amount)
100+
for c in closest:
101+
qmw_diff = (
102+
converted_model[k]["point_convexities"]["qmw"][i]
103+
- converted_model[k]["check_convexities"]["qmw"][c]
104+
)
105+
quw_diff = (
106+
converted_model[k]["point_convexities"]["quw"][i]
107+
- converted_model[k]["check_convexities"]["quw"][c]
108+
)
109+
if qmw_diff > 0:
110+
higher_than_qmw += 1
111+
if qmw_diff == 0:
112+
higher_than_qmw += 0.5
113+
if quw_diff > 0:
114+
higher_than_quw += 1
115+
if quw_diff == 0:
116+
higher_than_quw += 0.5
117+
comparison_qmw[k].append(higher_than_qmw / amount)
118+
comparison_quw[k].append(higher_than_quw / amount)
119+
120+
total_qmw = (
121+
comparison_qmw["natural"]
122+
+ comparison_qmw["suboptimal"]
123+
+ comparison_qmw["optimal"]
124+
)
125+
print("Average comparison (q(m|w)):", sum(total_qmw) / len(total_qmw))
126+
print(
127+
"Average natural language (q(m|w)):",
128+
sum(comparison_qmw["natural"]) / len(comparison_qmw["natural"]),
129+
)
130+
print(
131+
"Average optimal encoder (q(m|w)):",
132+
sum(comparison_qmw["optimal"]) / len(comparison_qmw["optimal"]),
95133
)
96-
97-
comparison_qmw = []
98-
comparison_quw = []
99-
for i, p in enumerate(points):
100-
higher_than_qmw = 0
101-
higher_than_quw = 0
102-
if i < 1501:
103-
closest = get_closest(check_frontier, p, amount)
104-
elif i >= 2601:
105-
closest = get_closest(check_natural, p, amount)
106-
else:
107-
closest = get_closest(points, p, amount + 1)
108-
for c in closest:
109-
if i < 1501:
110-
c += 1501
111-
if c == i:
112-
continue
113-
if convexities_qmw[i] - convexities_qmw[c] > 0:
114-
higher_than_qmw += 1
115-
if convexities_qmw[i] - convexities_qmw[c] == 0:
116-
higher_than_qmw += 0.5
117-
if convexities_quw[i] - convexities_quw[c] > 0:
118-
higher_than_quw += 1
119-
if convexities_quw[i] - convexities_quw[c] == 0:
120-
higher_than_quw += 0.5
121-
comparison_qmw.append(higher_than_qmw / amount)
122-
comparison_quw.append(higher_than_quw / amount)
123-
124-
print("Average comparison (q(m|w)):", sum(comparison_qmw) / len(comparison_qmw))
125-
print("Average natural language (q(m|w)):", sum(comparison_qmw[-110:]) / 110)
126-
print("Average optimal encoder (q(m|w)):", sum(comparison_qmw[:1501]) / 1501)
127134
print(
128135
"Average suboptimal encoder (q(m|w)):",
129-
sum(comparison_qmw[1501:-110]) / len(comparison_qmw[1501:-110]),
136+
sum(comparison_qmw["suboptimal"]) / len(comparison_qmw["suboptimal"]),
130137
)
131138
print()
132-
print("Average comparison (q(u|w)):", sum(comparison_quw) / len(comparison_quw))
133-
print("Average natural language (q(u|w)):", sum(comparison_quw[-110:]) / 110)
134-
print("Average optimal encoder (q(u|w)):", sum(comparison_quw[:1501]) / 1501)
139+
140+
total_quw = (
141+
comparison_quw["natural"]
142+
+ comparison_quw["suboptimal"]
143+
+ comparison_quw["optimal"]
144+
)
145+
print("Average comparison (q(u|w)):", sum(total_quw) / len(total_quw))
146+
print(
147+
"Average natural language (q(u|w)):",
148+
sum(comparison_quw["natural"]) / len(comparison_quw["natural"]),
149+
)
150+
print(
151+
"Average optimal encoder (q(u|w)):",
152+
sum(comparison_quw["optimal"]) / len(comparison_quw["optimal"]),
153+
)
135154
print(
136155
"Average suboptimal encoder (q(u|w)):",
137-
sum(comparison_quw[1501:-110]) / len(comparison_quw[1501:-110]),
156+
sum(comparison_quw["suboptimal"]) / len(comparison_quw["suboptimal"]),
138157
)
139158

140159

@@ -147,20 +166,22 @@ def check_difference_significance(amount: int, model: pd.DataFrame):
147166
amount (int): The number of neighbors to compare to.
148167
model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file.
149168
"""
150-
_, _, natural, _, _, check_natural, convexities_qmw, convexities_quw = (
151-
convert_model(model)
152-
)
153-
154-
offset = len(check_natural)
169+
converted = convert_model(model)
155170

156171
comparison_qmw = []
157172
comparison_quw = []
158173

159-
for i, p in enumerate(natural):
160-
closest = get_closest(check_natural, p, amount)
174+
for i, p in enumerate(converted["natural"]["points"]):
175+
closest = get_closest(converted["natural"]["check"], p, amount)
161176
for c in closest:
162-
comparison_qmw.append(convexities_qmw[i + offset] - convexities_qmw[c])
163-
comparison_quw.append(convexities_quw[i + offset] - convexities_quw[c])
177+
comparison_qmw.append(
178+
converted["natural"]["point_convexities"]["qmw"][i]
179+
- converted["natural"]["check_convexities"]["qmw"][c]
180+
)
181+
comparison_quw.append(
182+
converted["natural"]["point_convexities"]["quw"][i]
183+
- converted["natural"]["check_convexities"]["quw"][c]
184+
)
164185

165186
print(ttest_1samp(comparison_qmw, 0))
166187
print(ttest_1samp(comparison_quw, 0))

src/examples/colors/utils/minimize_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def minimize_model(name: str):
4545
"type": [],
4646
"beta": [],
4747
"optimality": [],
48-
"base_item_id": []
48+
"base_item_id": [],
4949
}
5050

5151
frontier = []
@@ -96,7 +96,7 @@ def minimize_model(name: str):
9696
df_data["optimality"].append(
9797
find_frontier_optimality(frontier, np.array([s.complexity, s.iwu]))
9898
)
99-
df_data["base_item_id"].append(i//10)
99+
df_data["base_item_id"].append(i // 10)
100100

101101
df = pd.DataFrame(data=df_data)
102102
df.index.name = "item_id"

0 commit comments

Comments
 (0)