1+ from typing import Any
12from scipy .stats import ttest_1samp
23import pandas as pd
34import numpy as np
@@ -20,66 +21,58 @@ def get_closest(points: np.ndarray, point: np.ndarray, amount: int) -> np.ndarra
2021
2122def convert_model (
2223 model : pd .DataFrame ,
23- ) -> tuple [
24- np .ndarray ,
25- np .ndarray ,
26- np .ndarray ,
27- np .ndarray ,
28- np .ndarray ,
29- np .ndarray ,
30- np .ndarray ,
31- np .ndarray ,
32- ]:
24+ ) -> dict [str , dict [str , Any ]]:
3325 """
34- Converts an input model taken from the color model `.csv` file and returns helpful arrays which are used in calculations
26+ Converts an input model taken from the color model `.csv` file and a helpful dict to use in calculations
3527
3628 Args:
3729 model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file
3830
3931 Returns:
40- tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
41- The following arrays, in order:
42- - The points from optimal encoders
43- - The points from suboptimal encoders
44- - The points from natural language encoders
45- - The concatination of all points (in order: optimal, suboptimal, natural)
46- - The concatination of suboptimal and natural language encoder points
47- - The concatination of optimal and suboptimal encoder points
48- - The quasi-convexity of the q(m|w) distributions for all encoders in points
49- - The quasi-convexity of the q(u|w) distributions for all encoders in points
32+ dict[str, dict[str, Any]]: Contains the following data
33+ For "optimal", "suboptimal", and "natural" has a dictionary with the following keys:
34+ - "points": All the points of encoders of that type
35+ - "check": All the points to check against in the neighbor comparison
36+ - "check_convexities": A dict with the quasi-convexity of the q(m|w) and q(u|w) distributions for the encoders in "check"
37+ - "point_convexities": A dict with the quasi-convexity of the q(m|w) and q(u|w) distributions for the encoders in "points"
5038 """
51- frontier = []
52- suboptimal = []
53- natural = []
54- convexities_qmw = []
55- convexities_quw = []
39+ output = {
40+ "optimal" : {
41+ "points" : [],
42+ "check" : [],
43+ "check_convexities" : {"qmw" : [], "quw" : []},
44+ "point_convexities" : {"qmw" : [], "quw" : []},
45+ },
46+ "suboptimal" : {
47+ "points" : [],
48+ "check" : [],
49+ "check_convexities" : {"qmw" : [], "quw" : []},
50+ "point_convexities" : {"qmw" : [], "quw" : []},
51+ },
52+ "natural" : {
53+ "points" : [],
54+ "check" : [],
55+ "check_convexities" : {"qmw" : [], "quw" : []},
56+ "point_convexities" : {"qmw" : [], "quw" : []},
57+ },
58+ }
5659
5760 for _ , row in model .iterrows ():
58- if row ["type" ] == "optimal" :
59- frontier .append ([row ["complexity" ], row ["accuracy" ]])
60- if row ["type" ] == "suboptimal" :
61- suboptimal .append ([row ["complexity" ], row ["accuracy" ]])
62- if row ["type" ] == "natural" :
63- natural .append ([row ["complexity" ], row ["accuracy" ]])
64- convexities_qmw .append (row ["convexity-qmw" ])
65- convexities_quw .append (row ["convexity-quw" ])
66-
67- frontier = np .array (frontier )
68- suboptimal = np .array (suboptimal )
69- natural = np .array (natural )
70- points = np .concat ((frontier , suboptimal , natural ))
71- check_frontier = np .concat ((suboptimal , natural ))
72- check_natural = np .concat ((frontier , suboptimal ))
73- return (
74- frontier ,
75- suboptimal ,
76- natural ,
77- points ,
78- check_frontier ,
79- check_natural ,
80- convexities_qmw ,
81- convexities_quw ,
82- )
61+ for k in output .keys ():
62+ if k == row ["type" ]:
63+ output [k ]["points" ].append ([row ["complexity" ], row ["accuracy" ]])
64+ output [k ]["point_convexities" ]["quw" ].append (row ["convexity-quw" ])
65+ output [k ]["point_convexities" ]["qmw" ].append (row ["convexity-qmw" ])
66+ if k != row ["type" ] or row ["type" ] == "suboptimal" :
67+ output [k ]["check" ].append ([row ["complexity" ], row ["accuracy" ]])
68+ output [k ]["check_convexities" ]["quw" ].append (row ["convexity-quw" ])
69+ output [k ]["check_convexities" ]["qmw" ].append (row ["convexity-qmw" ])
70+
71+ for k in output .keys ():
72+ output [k ]["points" ] = np .array (output [k ]["points" ])
73+ output [k ]["check" ] = np .array (output [k ]["check" ])
74+
75+ return output
8376
8477
8578def get_neighbor_comparison (amount : int , model : pd .DataFrame ):
@@ -90,51 +83,77 @@ def get_neighbor_comparison(amount: int, model: pd.DataFrame):
9083 amount (int): The number of neighbors to compare to.
9184 model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file.
9285 """
93- _ , _ , _ , points , check_frontier , check_natural , convexities_qmw , convexities_quw = (
94- convert_model (model )
86+ converted_model = convert_model (model )
87+
88+ comparison_qmw = {"optimal" : [], "suboptimal" : [], "natural" : []}
89+ comparison_quw = {"optimal" : [], "suboptimal" : [], "natural" : []}
90+ for k in converted_model .keys ():
91+ suboptimal = k == "suboptimal"
92+ for i , p in enumerate (converted_model [k ]["points" ]):
93+ higher_than_qmw = - 0.5 if suboptimal else 0
94+ higher_than_quw = - 0.5 if suboptimal else 0
95+ if suboptimal :
96+ # Check includes itself so get 11 instead of 10
97+ closest = get_closest (converted_model [k ]["check" ], p , amount + 1 )
98+ else :
99+ closest = get_closest (converted_model [k ]["check" ], p , amount )
100+ for c in closest :
101+ qmw_diff = (
102+ converted_model [k ]["point_convexities" ]["qmw" ][i ]
103+ - converted_model [k ]["check_convexities" ]["qmw" ][c ]
104+ )
105+ quw_diff = (
106+ converted_model [k ]["point_convexities" ]["quw" ][i ]
107+ - converted_model [k ]["check_convexities" ]["quw" ][c ]
108+ )
109+ if qmw_diff > 0 :
110+ higher_than_qmw += 1
111+ if qmw_diff == 0 :
112+ higher_than_qmw += 0.5
113+ if quw_diff > 0 :
114+ higher_than_quw += 1
115+ if quw_diff == 0 :
116+ higher_than_quw += 0.5
117+ comparison_qmw [k ].append (higher_than_qmw / amount )
118+ comparison_quw [k ].append (higher_than_quw / amount )
119+
120+ total_qmw = (
121+ comparison_qmw ["natural" ]
122+ + comparison_qmw ["suboptimal" ]
123+ + comparison_qmw ["optimal" ]
124+ )
125+ print ("Average comparison (q(m|w)):" , sum (total_qmw ) / len (total_qmw ))
126+ print (
127+ "Average natural language (q(m|w)):" ,
128+ sum (comparison_qmw ["natural" ]) / len (comparison_qmw ["natural" ]),
129+ )
130+ print (
131+ "Average optimal encoder (q(m|w)):" ,
132+ sum (comparison_qmw ["optimal" ]) / len (comparison_qmw ["optimal" ]),
95133 )
96-
97- comparison_qmw = []
98- comparison_quw = []
99- for i , p in enumerate (points ):
100- higher_than_qmw = 0
101- higher_than_quw = 0
102- if i < 1501 :
103- closest = get_closest (check_frontier , p , amount )
104- elif i >= 2601 :
105- closest = get_closest (check_natural , p , amount )
106- else :
107- closest = get_closest (points , p , amount + 1 )
108- for c in closest :
109- if i < 1501 :
110- c += 1501
111- if c == i :
112- continue
113- if convexities_qmw [i ] - convexities_qmw [c ] > 0 :
114- higher_than_qmw += 1
115- if convexities_qmw [i ] - convexities_qmw [c ] == 0 :
116- higher_than_qmw += 0.5
117- if convexities_quw [i ] - convexities_quw [c ] > 0 :
118- higher_than_quw += 1
119- if convexities_quw [i ] - convexities_quw [c ] == 0 :
120- higher_than_quw += 0.5
121- comparison_qmw .append (higher_than_qmw / amount )
122- comparison_quw .append (higher_than_quw / amount )
123-
124- print ("Average comparison (q(m|w)):" , sum (comparison_qmw ) / len (comparison_qmw ))
125- print ("Average natural language (q(m|w)):" , sum (comparison_qmw [- 110 :]) / 110 )
126- print ("Average optimal encoder (q(m|w)):" , sum (comparison_qmw [:1501 ]) / 1501 )
127134 print (
128135 "Average suboptimal encoder (q(m|w)):" ,
129- sum (comparison_qmw [1501 : - 110 ]) / len (comparison_qmw [1501 : - 110 ]),
136+ sum (comparison_qmw ["suboptimal" ]) / len (comparison_qmw ["suboptimal" ]),
130137 )
131138 print ()
132- print ("Average comparison (q(u|w)):" , sum (comparison_quw ) / len (comparison_quw ))
133- print ("Average natural language (q(u|w)):" , sum (comparison_quw [- 110 :]) / 110 )
134- print ("Average optimal encoder (q(u|w)):" , sum (comparison_quw [:1501 ]) / 1501 )
139+
140+ total_quw = (
141+ comparison_quw ["natural" ]
142+ + comparison_quw ["suboptimal" ]
143+ + comparison_quw ["optimal" ]
144+ )
145+ print ("Average comparison (q(u|w)):" , sum (total_quw ) / len (total_quw ))
146+ print (
147+ "Average natural language (q(u|w)):" ,
148+ sum (comparison_quw ["natural" ]) / len (comparison_quw ["natural" ]),
149+ )
150+ print (
151+ "Average optimal encoder (q(u|w)):" ,
152+ sum (comparison_quw ["optimal" ]) / len (comparison_quw ["optimal" ]),
153+ )
135154 print (
136155 "Average suboptimal encoder (q(u|w)):" ,
137- sum (comparison_quw [1501 : - 110 ]) / len (comparison_quw [1501 : - 110 ]),
156+ sum (comparison_quw ["suboptimal" ]) / len (comparison_quw ["suboptimal" ]),
138157 )
139158
140159
@@ -147,20 +166,22 @@ def check_difference_significance(amount: int, model: pd.DataFrame):
147166 amount (int): The number of neighbors to compare to.
148167 model (DataFrame): The DataFrame which can be directly loaded from the color model's minimized `.csv` file.
149168 """
150- _ , _ , natural , _ , _ , check_natural , convexities_qmw , convexities_quw = (
151- convert_model (model )
152- )
153-
154- offset = len (check_natural )
169+ converted = convert_model (model )
155170
156171 comparison_qmw = []
157172 comparison_quw = []
158173
159- for i , p in enumerate (natural ):
160- closest = get_closest (check_natural , p , amount )
174+ for i , p in enumerate (converted [ " natural" ][ "points" ] ):
175+ closest = get_closest (converted [ "natural" ][ "check" ] , p , amount )
161176 for c in closest :
162- comparison_qmw .append (convexities_qmw [i + offset ] - convexities_qmw [c ])
163- comparison_quw .append (convexities_quw [i + offset ] - convexities_quw [c ])
177+ comparison_qmw .append (
178+ converted ["natural" ]["point_convexities" ]["qmw" ][i ]
179+ - converted ["natural" ]["check_convexities" ]["qmw" ][c ]
180+ )
181+ comparison_quw .append (
182+ converted ["natural" ]["point_convexities" ]["quw" ][i ]
183+ - converted ["natural" ]["check_convexities" ]["quw" ][c ]
184+ )
164185
165186 print (ttest_1samp (comparison_qmw , 0 ))
166187 print (ttest_1samp (comparison_quw , 0 ))
0 commit comments