|
1 | 1 | import json |
2 | 2 | import os |
| 3 | +import sys |
3 | 4 | import matplotlib.pyplot as plt |
4 | 5 | import pandas as pd |
5 | 6 |
|
| 7 | +# ruff: noqa: F401 |
| 8 | +from openworm_ai.utils.llms import ( |
| 9 | + LLM_OLLAMA_LLAMA32_1B, |
| 10 | + LLM_OLLAMA_LLAMA32_3B, |
| 11 | + LLM_GPT4o, |
| 12 | + LLM_GEMINI_2F, |
| 13 | + LLM_CLAUDE37, |
| 14 | + LLM_GPT35, |
| 15 | + LLM_OLLAMA_PHI4, |
| 16 | + LLM_OLLAMA_GEMMA2, |
| 17 | + LLM_OLLAMA_GEMMA, |
| 18 | + LLM_OLLAMA_QWEN, |
| 19 | + LLM_OLLAMA_TINYLLAMA, |
| 20 | + ask_question_get_response, |
| 21 | +) |
| 22 | + |
6 | 23 | # Define model parameters (LLM parameter sizes in billions) |
7 | 24 | llm_parameters = { |
| 25 | + LLM_GPT4o: 1760, |
| 26 | + LLM_GPT35: 175, |
8 | 27 | "GPT3.5": 20, |
9 | 28 | "Phi4": 14, |
10 | 29 | "Gemma2": 9, |
11 | 30 | "Gemma": 7, |
12 | 31 | "Qwen": 4, |
13 | 32 | "Llama3.2": 1, |
14 | | - "TinyLlama":1.1, |
| 33 | + "TinyLlama": 1.1, |
15 | 34 | "GPT4o": 1760, |
16 | 35 | "Gemini": 500, |
17 | | - "Claude 3.5 Sonnet": 175 |
| 36 | + "Claude 3.5 Sonnet": 175, |
18 | 37 | } |
19 | 38 |
|
20 | 39 | # Define model distributors for coloring |
21 | 40 | model_distributors = { |
| 41 | + LLM_GPT4o: "OpenAI", |
| 42 | + LLM_GPT35: "OpenAI", |
22 | 43 | "GPT3.5": "OpenAI", |
23 | 44 | "GPT4o": "OpenAI", |
24 | 45 | "Phi4": "Microsoft", |
|
28 | 49 | "Claude 3.5 Sonnet": "Anthropic", |
29 | 50 | "Qwen": "Alibaba", |
30 | 51 | "Llama3.2": "Meta", |
31 | | - "TinyLlama":"Open Source" |
| 52 | + "TinyLlama": "Open Source", |
32 | 53 | } |
33 | 54 |
|
34 | 55 | # Define quiz categories and corresponding file paths |
35 | 56 | file_paths = { |
36 | | - #"General Knowledge": "openworm_ai/quiz/scores/general/llm_scores_general_24-02-25.json", |
37 | | - #"Science": "openworm_ai/quiz/scores/science/llm_scores_science_24-02-25.json", |
38 | | - #"C. Elegans": "openworm_ai/quiz/scores/celegans/llm_scores_celegans_24-02-25.json", |
39 | | - "RAG":"openworm_ai/quiz/scores/rag/llm_scores_rag_16-03-25_2.json" |
| 57 | + # "General Knowledge": "openworm_ai/quiz/scores/general/llm_scores_general_24-02-25.json", |
| 58 | + # "Science": "openworm_ai/quiz/scores/science/llm_scores_science_24-02-25.json", |
| 59 | + # "C. Elegans": "openworm_ai/quiz/scores/celegans/llm_scores_celegans_24-02-25.json", |
| 60 | + "RAG": "openworm_ai/quiz/scores/rag/llm_scores_rag_16-03-25_2.json" |
40 | 61 | } |
41 | 62 |
|
42 | 63 | # Folder to save figures |
|
51 | 72 | "Microsoft": "purple", |
52 | 73 | "Alibaba": "orange", |
53 | 74 | "Meta": "cyan", |
54 | | - "Open Source":"yellow" |
| 75 | + "Open Source": "yellow", |
55 | 76 | } |
56 | 77 |
|
57 | 78 | # Process each quiz category |
58 | 79 | for category, file_path in file_paths.items(): |
59 | | - save_path = os.path.join(figures_folder, f"llm_accuracy_vs_parameters_{category.replace(' ', '_').lower()}.png") |
| 80 | + save_path = os.path.join( |
| 81 | + figures_folder, |
| 82 | + f"llm_accuracy_vs_parameters_{category.replace(' ', '_').lower()}.png", |
| 83 | + ) |
60 | 84 |
|
61 | 85 | # Check if the file exists |
62 | 86 | if not os.path.exists(file_path): |
63 | | - print(f"⚠️ Warning: File not found - {file_path}. Skipping this category.") |
| 87 | + print(f"Warning: File not found - {file_path}. Skipping this category.") |
64 | 88 | continue |
65 | 89 |
|
66 | 90 | # Load JSON data |
|
72 | 96 | for result in data.get("Results", []): # Use .get() to avoid KeyError |
73 | 97 | for key in llm_parameters: |
74 | 98 | if key.lower() in result["LLM"].lower(): |
75 | | - category_results.append({ |
76 | | - "Model": key, |
77 | | - "Accuracy (%)": result["Accuracy (%)"], |
78 | | - "Parameters (B)": llm_parameters[key], |
79 | | - "Distributor": model_distributors.get(key, "Unknown") |
80 | | - }) |
| 99 | + category_results.append( |
| 100 | + { |
| 101 | + "Model": key, |
| 102 | + "Accuracy (%)": result["Accuracy (%)"], |
| 103 | + "Parameters (B)": llm_parameters[key], |
| 104 | + "Distributor": model_distributors.get(key, "Unknown"), |
| 105 | + } |
| 106 | + ) |
81 | 107 | break |
82 | 108 |
|
83 | 109 | # Skip if no data |
84 | 110 | if not category_results: |
85 | | - print(f"⚠️ No valid results found in {file_path}. Skipping...") |
| 111 | + print(f"No valid results found in {file_path}. Skipping...") |
86 | 112 | continue |
87 | 113 |
|
88 | 114 | # Convert to DataFrame |
|
94 | 120 | # Scatter plot with model labels, colored by distributor |
95 | 121 | for distributor, color in distributor_colors.items(): |
96 | 122 | subset = df[df["Distributor"] == distributor] |
97 | | - plt.scatter(subset["Parameters (B)"], subset["Accuracy (%)"], s=100, color=color, label=distributor, edgecolor="black") |
| 123 | + plt.scatter( |
| 124 | + subset["Parameters (B)"], |
| 125 | + subset["Accuracy (%)"], |
| 126 | + s=100, |
| 127 | + color=color, |
| 128 | + label=distributor, |
| 129 | + edgecolor="black", |
| 130 | + ) |
98 | 131 |
|
99 | 132 | # Add model labels to each point |
100 | 133 | for i, row in df.iterrows(): |
101 | | - plt.text(row["Parameters (B)"], row["Accuracy (%)"], row["Model"], fontsize=10, ha="right", va="bottom") |
| 134 | + plt.text( |
| 135 | + row["Parameters (B)"], |
| 136 | + row["Accuracy (%)"], |
| 137 | + row["Model"], |
| 138 | + fontsize=10, |
| 139 | + ha="right", |
| 140 | + va="bottom", |
| 141 | + ) |
102 | 142 |
|
103 | 143 | # Log scale for x-axis (model parameters) |
104 | 144 | plt.xscale("log") |
|
113 | 153 | # Save figure |
114 | 154 | plt.legend() |
115 | 155 | plt.savefig(save_path) |
116 | | - print(f"✅ Saved plot: {save_path}") |
117 | | - plt.show() |
| 156 | + print(f"Saved plot: {save_path}") |
| 157 | + if "-nogui" not in sys.argv: |
| 158 | + plt.show() |
0 commit comments