-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathseaborn-charts3.py
More file actions
237 lines (189 loc) · 8.22 KB
/
seaborn-charts3.py
File metadata and controls
237 lines (189 loc) · 8.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "pandas",
# "seaborn",
# "matplotlib",
# "statsmodels",
# ]
# ///
# See https://docs.astral.sh/uv/guides/scripts/#using-a-shebang-to-create-an-executable-file
"""seaborn-charts3.py here.
At https://github.com/wilsonmar/python-samples/blob/main/seaborn-charts3.py
Create two scatter charts of LLM evaluation data, displayed side-by-side using plt.subplots().
"""
__last_change__ = "25-10-09 v017 + comments H (Milliseconds (Speed)) added :seaborn-charts3.py"
# Internal imports (no pip/uv add needed):
from datetime import datetime, timezone
import sys
try:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.multivariate.manova as manova
except Exception as e:
print(f"Python module import failed: {e}")
print("Please activate your virtual environment:\n python3 -m venv venv\n source venv/bin/activate")
sys.exit(9)
# For wall time measurements: record when the program started (right after
# the imports) as a naive local-time datetime, for later elapsed-time reporting.
pgm_strt_datetimestamp = datetime.now()
def pgm_summary(std_strt_datetimestamp):
    """Print the wall-clock time elapsed since the given start timestamp.

    Args:
        std_strt_datetimestamp: ``datetime`` captured when timing started.
            It may be naive (as returned by ``datetime.now()``) or
            timezone-aware; the stop time is taken with the same tzinfo so
            the subtraction is always valid.
    """
    # Measure against the caller-supplied start time instead of silently
    # reading a module-level global, so the argument is actually honoured.
    pgm_stop_datetimestamp = datetime.now(tz=std_strt_datetimestamp.tzinfo)
    pgm_elapsed_wall_time = pgm_stop_datetimestamp - std_strt_datetimestamp
    # A timedelta renders as H:MM:SS.ffffff; convert it to a plain number of
    # seconds so the value matches the "seconds" label in the message.
    print(f"SUMMARY: {pgm_elapsed_wall_time.total_seconds():.3f} seconds.")
# --- Configuration & Helpers ---
# Overall figure dimensions as (width, height) in inches; the width is large
# enough to hold two subplots side by side.
FIGURE_SIZE = (12, 7)

# Font sizes, keyed by the kind of text element they are applied to.
FONT_SIZE = dict(
    title=16,
    label=14,
    annotate=12,
    overlay=18,
    datestamp=10,
    manova=14,
)

# (smallest, largest) marker sizes passed to seaborn's ``sizes=`` argument.
POINT_SIZES = (50, 200)
def add_annotations(ax, df, formatted_time, f_value, p_value, is_first_plot):
    """Decorate one subplot with callout text, a trendline and point labels.

    Args:
        ax: Matplotlib axis to draw on.
        df: DataFrame supplying the 'LLM' and 'USD cents' columns plus the
            x column ('MilliSecs' for the first plot, 'Accuracy' otherwise).
        formatted_time: Datestamp string; part of the shared signature but
            not drawn by this function.
        f_value: MANOVA F statistic, shown on the first plot only.
        p_value: MANOVA p-value, shown next to the F statistic.
        is_first_plot: True for the cost-vs-speed chart, False for the
            cost-vs-accuracy chart.
    """
    cost_col = 'USD cents'

    if is_first_plot:
        horiz_col = 'MilliSecs'

        # MANOVA results in a translucent white box (first plot only).
        stats_box = dict(boxstyle="round,pad=0.3", fc='white', ec='none', alpha=0.7)
        ax.text(
            0.25, 0.50, f"F={f_value:.2f} p={p_value:.2f}",
            transform=ax.transAxes,
            fontsize=FONT_SIZE['manova'],
            verticalalignment='top',
            bbox=stats_box,
        )

        # Corner callouts as (x, y, caption, colour), in axes-fraction coordinates.
        callouts = (
            (0.90, 0.90, "Not Worth It!", '#F08080'),
            (0.20, 0.20, 'Bargain!', '#006400'),
        )
        for frac_x, frac_y, caption, colour in callouts:
            ax.text(
                frac_x, frac_y, caption,
                horizontalalignment='right',
                verticalalignment='bottom',
                fontstyle='italic',
                transform=ax.transAxes,
                fontsize=FONT_SIZE['overlay'],
                color=colour,
            )
    else:
        horiz_col = 'Accuracy'

    # Dashed grey regression line with no scatter points or confidence band.
    sns.regplot(
        data=df, x=horiz_col, y=cost_col,
        scatter=False, ci=None, color='grey',
        line_kws={'linestyle': '--'}, ax=ax,
    )

    # Label every data point with its LLM name, nudged slightly to the right.
    for llm_name, x_val, y_val in zip(df['LLM'], df[horiz_col], df[cost_col]):
        ax.annotate(
            text=llm_name,
            xy=(x_val, y_val),
            xytext=(7, 0),
            textcoords='offset points',
            ha='left',
            fontsize=FONT_SIZE['annotate'],
        )
def plot_chart(ax, df, formatted_time, f_value, p_value, is_first_plot):
    """Draw one of the two LLM-evaluation scatter charts on the given axis.

    The first chart plots cost against speed (points coloured by Accuracy);
    the second plots cost against accuracy (points coloured by MilliSecs)
    with its x-axis reversed, and also carries the datestamp.

    Args:
        ax: Matplotlib axis to draw on.
        df: DataFrame with 'Accuracy', 'MilliSecs', 'USD cents' and 'CoV'
            columns (plus whatever add_annotations() reads). It is not
            modified; a sorted copy is used internally.
        formatted_time: Datestamp string drawn below the second chart.
        f_value: MANOVA F statistic to display.
        p_value: MANOVA p-value to display.
        is_first_plot: True for the cost-vs-speed chart, False for the
            cost-vs-accuracy chart.
    """
    # Plot-specific column choices, title and x-axis label.
    if is_first_plot:
        x_col, hue_col, title, x_label = (
            'MilliSecs', 'Accuracy', '4D-LLM Eval: Cost vs Speed Scatter Plot',
            'Milliseconds (Speed)',
        )
    else:
        # Second plot: Accuracy vs cost, coloured by MilliSecs.
        x_col, hue_col, title, x_label = (
            'Accuracy', 'MilliSecs', '4D-LLM Eval: Cost vs Accuracy Scatter Plot',
            'Accuracy %',
        )

    # Sorting keeps point and legend order consistent. Sort into a new frame
    # rather than in place so the caller's DataFrame is never mutated.
    df = df.sort_values("Accuracy", ascending=False)

    # Main scatter plot: marker size encodes 'CoV', colour encodes hue_col.
    sns.scatterplot(
        data=df, x=x_col, y='USD cents', markers=True,
        hue=hue_col, edgecolor='black', size='CoV',
        sizes=POINT_SIZES, palette='RdYlGn', ax=ax,
    )

    if is_first_plot:
        # Speed axis starts at 0 and ends 5% past the slowest model.
        max_x = df['MilliSecs'].max()
        ax.set_xlim(0, max_x * 1.05)
    # The accuracy axis of the second plot keeps matplotlib's automatic limits.

    # Styling and legend (legend is placed outside the plot area, untitled).
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    sns.despine(ax=ax, trim=True, offset=5)
    ax.legend(
        bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0., title=None,
    )

    # Trendline, point labels and the first plot's callouts.
    add_annotations(ax, df, formatted_time, f_value, p_value, is_first_plot)

    if not is_first_plot:
        # Reverse the x-axis of the second plot.
        ax.invert_xaxis()

        # Corner callouts as (x, y, caption, colour), in axes-fraction coordinates.
        for frac_x, frac_y, caption, colour in (
            (0.90, 0.90, "Not Worth It!", '#F08080'),
            (0.20, 0.20, 'Bargain!', '#006400'),
        ):
            ax.text(
                frac_x, frac_y, caption,
                horizontalalignment='right', verticalalignment='bottom',
                fontstyle='italic', transform=ax.transAxes,
                fontsize=FONT_SIZE['overlay'], color=colour,
            )

        # Datestamp, positioned relative to this (second) subplot's axes;
        # x=1.3 places it to the right of the plot area.
        ax.text(
            1.3, -.08, formatted_time,
            horizontalalignment='right', verticalalignment='bottom',
            transform=ax.transAxes, fontsize=FONT_SIZE['datestamp'], color='grey',
        )

        # MANOVA results in a translucent white box.
        ax.text(
            0.25, 0.50, f"F={f_value:.2f} p={p_value:.2f}",
            transform=ax.transAxes, fontsize=FONT_SIZE['manova'], verticalalignment='top',
            bbox=dict(boxstyle="round,pad=0.3", fc='white', ec='none', alpha=0.7),
        )

    # Titles and axis labels.
    ax.set_title(title, fontsize=FONT_SIZE['title'])
    ax.set_xlabel(x_label)
    ax.set_ylabel('USD cents cost')
# --- Main Execution ---
if __name__ == '__main__':
    # Data loading: the CSV is read from the current working directory.
    try:
        df = pd.read_csv('seaborn-charts.csv')
    except FileNotFoundError:
        print("Error: 'seaborn-charts.csv' not found. Please ensure the file exists.")
        sys.exit(1)

    print("seaborn-charts3.py generating scatter plots...")

    # Datestamp (UTC) drawn on the chart.
    formatted_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')

    # Calculate MANOVA values: Accuracy and cost as joint dependent variables,
    # MilliSecs as the predictor. The formula refers to `dfs` by name, so this
    # variable must remain in scope for the from_formula() call.
    dfs = df[['Accuracy','USD cents']]
    manova_result = manova.MANOVA.from_formula('dfs ~ MilliSecs', data=df)
    results = manova_result.mv_test()
    # Take the first row of the test-statistics table for the MilliSecs term
    # (NOTE(review): presumably Wilks' lambda in statsmodels' output - confirm).
    f_value = results.results['MilliSecs']['stat']['F Value'].iloc[0]
    p_value = results.results['MilliSecs']['stat']['Pr > F'].iloc[0]

    # Apply the theme BEFORE creating the figure: set_theme() changes
    # matplotlib's global rcParams, which axes pick up when they are created,
    # so setting it afterwards leaves existing axes only partially themed.
    sns.set_theme(style='darkgrid')

    # Create the figure with 1 row and 2 columns of axes:
    # axes[0] holds the first plot, axes[1] the second.
    fig, axes = plt.subplots(1, 2, figsize=FIGURE_SIZE)

    # Plot generation. Each call gets its own copy so that nothing done to
    # the DataFrame for one plot can affect the other.
    plot_chart(axes[0], df.copy(), formatted_time, f_value, p_value, is_first_plot=True)
    plot_chart(axes[1], df.copy(), formatted_time, f_value, p_value, is_first_plot=False)

    # Final display: fit both subplots within the full figure rectangle.
    plt.tight_layout(rect=[0, 0, 1, 1])
    plt.show()

    pgm_summary(pgm_strt_datetimestamp)