Skip to content

Commit 3803ba2

Browse files
committed
-added command for physics entities calculation
-updated gitignore with the reporting path
1 parent 56ca55d commit 3803ba2

4 files changed

Lines changed: 185 additions & 7 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ __pycache__
33
.idea
44
**/*.sqlite3
55
**/.env
6-
.web/classifier_reports
6+
./web/classifier_reports

run-physics-control.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env bash
2+
set -euo pipefail
3+
4+
cd "$(dirname "$0")/web"
5+
6+
python manage.py physics_control \
7+
--session "test-003-20260407"

web/categorizer/management/commands/evaluate_classifier.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def handle(self, *args, **options):
139139
auc_c, auc_d = self._plot_roc_curve(
140140
output_dir, tp_results, obs_results, ground_truth
141141
)
142+
self._plot_roc_curve_labeled(output_dir, tp_preds, obs_preds, results_limit)
142143
self._plot_precision_recall_curve(
143144
output_dir, tp_results, obs_results, ground_truth
144145
)
@@ -284,24 +285,74 @@ def _plot_roc_curve(self, output_dir, tp_results, obs_results, ground_truth):
284285
fpr_d, tpr_d, _ = roc_curve(labels_d, scores_d)
285286
auc_d = auc(fpr_d, tpr_d)
286287

288+
fig, ax = plt.subplots(figsize=(8, 8))
289+
ax.plot(fpr_c, tpr_c, linewidth=2, label=f"Included (AUC={auc_c:.3f})")
290+
ax.plot(fpr_d, tpr_d, linewidth=2, label=f"Excluded (AUC={auc_d:.3f})")
291+
ax.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random")
292+
ax.set_xlabel("False Positive Rate")
293+
ax.set_ylabel("True Positive Rate")
294+
ax.set_title("ROC Curve (MathWorld identifier included vs. excluded)")
295+
ax.legend(loc="lower right")
296+
fig.tight_layout()
297+
fig.savefig(os.path.join(output_dir, "roc_curve.png"), dpi=150)
298+
plt.close(fig)
299+
300+
return auc_c, auc_d
301+
302+
def _plot_roc_curve_labeled(self, output_dir, tp_preds, obs_preds, limit):
303+
common_ids = sorted(set(tp_preds) & set(obs_preds))[:limit]
304+
if not common_ids:
305+
self.stdout.write(
306+
self.style.WARNING("Skipping labeled ROC: no common items")
307+
)
308+
return
309+
310+
labels = np.array([1 if tp_preds[i][0] else 0 for i in common_ids], dtype=int)
311+
scores_with = np.array(
312+
[tp_preds[i][1] / 100.0 for i in common_ids], dtype=float
313+
)
314+
scores_without = np.array(
315+
[obs_preds[i][1] / 100.0 for i in common_ids], dtype=float
316+
)
317+
318+
if len(np.unique(labels)) < 2:
319+
self.stdout.write(
320+
self.style.WARNING(
321+
"Skipping labeled ROC: include-MW-ID answers are single-class"
322+
)
323+
)
324+
return
325+
326+
fpr_w, tpr_w, _ = roc_curve(labels, scores_with)
327+
auc_w = auc(fpr_w, tpr_w)
328+
fpr_wo, tpr_wo, _ = roc_curve(labels, scores_without)
329+
auc_wo = auc(fpr_wo, tpr_wo)
330+
287331
fig, ax = plt.subplots(figsize=(8, 8))
288332
ax.plot(
289-
fpr_c, tpr_c, linewidth=2, label=f"Table C: tp+results (AUC={auc_c:.3f})"
333+
fpr_w,
334+
tpr_w,
335+
linewidth=2,
336+
label=f"With MathWorld ID (AUC={auc_w:.3f})",
290337
)
291338
ax.plot(
292-
fpr_d, tpr_d, linewidth=2, label=f"Table D: obs+results (AUC={auc_d:.3f})"
339+
fpr_wo,
340+
tpr_wo,
341+
linewidth=2,
342+
label=f"Without MathWorld ID (AUC={auc_wo:.3f})",
293343
)
294344
ax.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random")
295345
ax.set_xlabel("False Positive Rate")
296346
ax.set_ylabel("True Positive Rate")
297-
ax.set_title("ROC Curve")
347+
ax.set_title(
348+
f"ROC Curve (n={len(common_ids)}, "
349+
"labels = include-MW-ID aggregated answer)"
350+
)
298351
ax.legend(loc="lower right")
299352
fig.tight_layout()
300-
fig.savefig(os.path.join(output_dir, "roc_curve.png"), dpi=150)
353+
fig.savefig(os.path.join(output_dir, "roc_curve_labeled.png"), dpi=150)
301354
plt.close(fig)
302355

303-
return auc_c, auc_d
304-
305356
def _plot_precision_recall_curve(
306357
self, output_dir, tp_results, obs_results, ground_truth
307358
):
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from concepts.models import CategorizerResult, Item
2+
from django.core.management.base import BaseCommand
3+
4+
5+
class Command(BaseCommand):
6+
help = (
7+
"Summarize the physics-concepts control experiment: distribution of "
8+
"'math' votes per item, mean confidence per group, and the full list "
9+
"of items that received a unanimous 'math' vote."
10+
)
11+
12+
def add_arguments(self, parser):
13+
parser.add_argument(
14+
"--session",
15+
type=str,
16+
required=True,
17+
help="Session name of the physics-concepts categorization run.",
18+
)
19+
20+
def handle(self, *args, **options):
21+
session = options["session"]
22+
23+
by_item = {}
24+
for item_id, answer, confidence in CategorizerResult.objects.filter(
25+
session_name=session
26+
).values_list("item_id", "result_answer", "result_confidence"):
27+
by_item.setdefault(item_id, []).append((bool(answer), int(confidence)))
28+
29+
if not by_item:
30+
self.stdout.write(
31+
self.style.ERROR(f"No CategorizerResult rows for session '{session}'")
32+
)
33+
return
34+
35+
# Per-item aggregation
36+
groups = {0: [], 1: [], 2: [], 3: []} # yes_votes -> list[(item_id, [conf])]
37+
for item_id, judgments in by_item.items():
38+
yes_votes = sum(1 for ans, _ in judgments if ans)
39+
confidences = [c for _, c in judgments]
40+
groups.setdefault(yes_votes, []).append((item_id, confidences))
41+
42+
total_items = sum(len(v) for v in groups.values())
43+
44+
self.stdout.write(
45+
f"\nPhysics control — session '{session}' — {total_items} items\n"
46+
)
47+
48+
# ----- Distribution + mean confidence per group -----
49+
rows = []
50+
for k in sorted(groups):
51+
items = groups[k]
52+
n = len(items)
53+
if n == 0:
54+
mean_conf = 0.0
55+
else:
56+
all_confs = [c for _, confs in items for c in confs]
57+
mean_conf = sum(all_confs) / len(all_confs)
58+
rows.append((k, n, mean_conf))
59+
60+
header = f"{'Judges voting math':>20} {'Items':>8} {'Mean confidence':>18}"
61+
self.stdout.write(header)
62+
self.stdout.write("-" * len(header))
63+
for k, n, mean_conf in rows:
64+
self.stdout.write(f"{k:>20} {n:>8} {mean_conf:>17.1f}")
65+
66+
# LaTeX tabular for the augmented distribution table
67+
self.stdout.write("\nLaTeX tabular (vote distribution + mean confidence):\n")
68+
self.stdout.write(
69+
"\\begin{tabular}"
70+
"{>{\\raggedleft\\arraybackslash}p{0.28\\textwidth}"
71+
">{\\raggedleft\\arraybackslash}p{0.20\\textwidth}"
72+
">{\\raggedleft\\arraybackslash}p{0.20\\textwidth}}"
73+
)
74+
self.stdout.write(" \\toprule")
75+
self.stdout.write(
76+
" Judges voting ``math'' & Number of items & " "Mean confidence \\\\"
77+
)
78+
self.stdout.write(" \\midrule")
79+
for k, n, mean_conf in rows:
80+
self.stdout.write(f" {k} & {n:>3} & {mean_conf:5.1f} \\\\")
81+
self.stdout.write(" \\bottomrule")
82+
self.stdout.write("\\end{tabular}\n")
83+
84+
# ----- Exhaustive list of items with unanimous 'math' votes -----
85+
unanimous = groups.get(3, [])
86+
if not unanimous:
87+
self.stdout.write("\nNo items received a unanimous 'math' vote.\n")
88+
return
89+
90+
unanimous_ids = [item_id for item_id, _ in unanimous]
91+
items_by_id = {i.id: i for i in Item.objects.filter(id__in=unanimous_ids)}
92+
93+
# Attach name + mean confidence per item, ordered by name
94+
enriched = []
95+
for item_id, confs in unanimous:
96+
item = items_by_id.get(item_id)
97+
name = item.name if item and item.name else f"(item #{item_id})"
98+
mean_conf = sum(confs) / len(confs) if confs else 0.0
99+
enriched.append((name, mean_conf))
100+
enriched.sort(key=lambda x: (x[0] or "").lower())
101+
102+
self.stdout.write(f"\nItems with unanimous 'math' vote ({len(enriched)}):\n")
103+
for name, mean_conf in enriched:
104+
self.stdout.write(f" - {name} (mean confidence {mean_conf:.1f})")
105+
106+
# LaTeX tabular for the unanimous-items list
107+
self.stdout.write("\nLaTeX tabular (unanimous 'math' items):\n")
108+
self.stdout.write(
109+
"\\begin{tabular}"
110+
"{>{\\raggedright\\arraybackslash}p{0.60\\textwidth}"
111+
">{\\raggedleft\\arraybackslash}p{0.20\\textwidth}}"
112+
)
113+
self.stdout.write(" \\toprule")
114+
self.stdout.write(" Concept & Mean confidence \\\\")
115+
self.stdout.write(" \\midrule")
116+
for name, mean_conf in enriched:
117+
safe = (name or "").replace("&", "\\&").replace("_", "\\_")
118+
self.stdout.write(f" {safe} & {mean_conf:5.1f} \\\\")
119+
self.stdout.write(" \\bottomrule")
120+
self.stdout.write("\\end{tabular}\n")

0 commit comments

Comments
 (0)