-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathml_behavior_train_baseline.py
More file actions
161 lines (139 loc) · 5.06 KB
/
Copy pathml_behavior_train_baseline.py
File metadata and controls
161 lines (139 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
"""Train multinomial logistic baseline on manifest meta-features (#416 Wave 2).
Requires: pip install scikit-learn
Inputs: behavior_dataset_manifest@v1 (see ``make ml-build-behavior-dataset``).
Outputs:
- behavior_logistic_export@v1.json — weights for ``processor.behavior_recognition.weights_path``
- predictions JSON for ``make ml-build-behavior-train-report``
"""
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any
# Make the in-repo ``app`` directory importable when this script is run directly
# from a checkout. ``parent.parent`` means the repo root is taken to be one level
# above the directory containing this script (e.g. ``<repo>/scripts/``).
_REPO_ROOT = Path(__file__).resolve().parent.parent
_APP = _REPO_ROOT / "app"
if _APP.is_dir() and str(_APP) not in sys.path:
    # Insert at position 0 so the repo's ``app`` copy wins over any other
    # ``shared`` package found later on sys.path.
    sys.path.insert(0, str(_APP))
# This import must stay below the sys.path adjustment above.
# NOTE(review): EXPORT_SCHEMA is not referenced in the visible code of this file — confirm before removing.
from shared.behavior_logistic_train import EXPORT_SCHEMA, fit_behavior_logistic_export
def _meta_features(row: dict[str, Any]) -> list[float]:
frame_rows = float(row.get("frame_rows") or 0)
subject_count = float(row.get("subject_count") or 0)
species = row.get("species_names") or []
nsp = float(len(species)) if isinstance(species, list) else 0.0
return [
math.log1p(max(0.0, frame_rows)),
subject_count / 20.0,
nsp / 10.0,
]
def _dominant_behavior_id(row: dict[str, Any]) -> int | None:
raw = row.get("behavior_counts") or {}
if not isinstance(raw, dict) or not raw:
return None
best_id: int | None = None
best_n = -1
for k, v in raw.items():
try:
bid = int(k)
n = int(v)
except (TypeError, ValueError):
continue
if n > best_n:
best_n = n
best_id = bid
return best_id
def _taxonomy_labels(taxonomy: Any) -> dict[int, str]:
    """Map taxonomy ids to stripped, lower-cased labels, skipping malformed entries."""
    id_to_label: dict[int, str] = {}
    for entry in taxonomy:
        if not isinstance(entry, dict):
            continue
        try:
            bid = int(entry["id"])
        except (KeyError, TypeError, ValueError):
            continue
        label = str(entry.get("label") or "").strip().lower()
        if label:
            id_to_label[bid] = label
    return id_to_label


def _training_rows(
    videos: list[dict[str, Any]],
    id_to_label: dict[int, str],
) -> tuple[list[list[float]], list[str]]:
    """Collect (features, label) pairs for videos whose dominant behavior has a label."""
    X_list: list[list[float]] = []
    y_list: list[str] = []
    for row in videos:
        dom_id = _dominant_behavior_id(row)
        if dom_id is None or dom_id not in id_to_label:
            continue
        X_list.append(_meta_features(row))
        y_list.append(id_to_label[dom_id])
    return X_list, y_list


def _predict_rows(
    clf: Any,
    classes: list[str],
    videos: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Score every video that has a non-empty ``video_key`` with one batched call."""
    import numpy as np

    keys: list[str] = []
    feats: list[list[float]] = []
    for row in videos:
        key = str(row.get("video_key") or "").strip()
        if key:
            keys.append(key)
            feats.append(_meta_features(row))
    if not keys:
        # predict_proba cannot be called on an empty matrix.
        return []

    # One predict_proba call for all videos instead of one call per video.
    proba = np.asarray(clf.predict_proba(np.array(feats, dtype=np.float64)))
    if proba.ndim != 2 or proba.shape[1] != len(classes):
        # Without this check a mismatch surfaces as an IndexError or a silently
        # truncated probability dict.
        raise RuntimeError(
            f"export lists {len(classes)} labels but classifier returned "
            f"probabilities of shape {proba.shape}"
        )

    pred_rows: list[dict[str, Any]] = []
    for key, row_proba in zip(keys, proba):
        idx = int(np.argmax(row_proba))
        pred_rows.append(
            {
                "video_key": key,
                "pred_label": classes[idx],
                "confidence": round(float(row_proba[idx]), 6),
                "proba": {label: round(float(p), 6) for label, p in zip(classes, row_proba)},
            }
        )
    return pred_rows


def train_and_export(
    manifest: dict[str, Any],
    *,
    max_iter: int = 500,
    seed: int = 42,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Fit the logistic baseline on a dataset manifest and score every video.

    A video's training label is the taxonomy label of its dominant behavior id.
    Videos whose dominant id is missing, or absent from the taxonomy, are left
    out of training. Predictions are produced for every video row with a
    non-empty ``video_key``, whether or not it was used for training.

    Args:
        manifest: Parsed ``behavior_dataset_manifest@v1`` document.
        max_iter: Forwarded to ``fit_behavior_logistic_export``.
        seed: Forwarded to ``fit_behavior_logistic_export``.

    Returns:
        ``(export, predictions)``. ``export`` is the dict returned by
        ``fit_behavior_logistic_export``; ``predictions`` is a
        ``behavior_predictions@v1`` dict.

    Raises:
        ValueError: Wrong manifest schema, no usable taxonomy labels, or no
            video whose dominant behavior maps to a taxonomy label.
        RuntimeError: The exported label list does not match the classifier's
            probability columns.
    """
    if str(manifest.get("schema") or "") != "behavior_dataset_manifest@v1":
        raise ValueError("manifest schema must be behavior_dataset_manifest@v1")
    id_to_label = _taxonomy_labels(manifest.get("taxonomy") or [])
    if not id_to_label:
        raise ValueError("empty taxonomy labels")

    videos = [row for row in (manifest.get("videos") or []) if isinstance(row, dict)]
    X_list, y_list = _training_rows(videos, id_to_label)
    if not X_list:
        # Fail early with a clear message instead of an opaque error from the fitter.
        raise ValueError("no videos with a dominant behavior present in the taxonomy")

    export, clf = fit_behavior_logistic_export(
        X_list,
        y_list,
        max_iter=max_iter,
        seed=seed,
        feature_mode="manifest_meta_v1",
        extra={"manifest_dataset_id": manifest.get("dataset_id")},
    )
    # Probability columns are labeled with the export's label order, as before.
    classes = [str(c) for c in (export.get("labels") or [])]
    pred_rows = _predict_rows(clf, classes, videos)
    predictions = {"schema": "behavior_predictions@v1", "predictions": pred_rows}
    return export, predictions
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--manifest", required=True)
p.add_argument("--export-out", required=True)
p.add_argument("--predictions-out", required=True)
p.add_argument("--max-iter", type=int, default=500)
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
def main() -> int:
    """Command-line entry point: train on ``--manifest`` and write both JSON outputs.

    All paths are ``~``-expanded; output paths are also resolved. Parent
    directories of both output files are created when missing. A one-line JSON
    summary is printed to stdout.

    Returns:
        Process exit status (0 on success; failures propagate as exceptions).
    """
    args = _parse_args()
    manifest_path = Path(args.manifest).expanduser()
    export_path = Path(args.export_out).expanduser().resolve()
    predictions_path = Path(args.predictions_out).expanduser().resolve()

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    export, predictions = train_and_export(manifest, max_iter=args.max_iter, seed=args.seed)

    # Both outputs get their parent directory created; previously only the
    # export's did, so a fresh --predictions-out directory made the write fail.
    _write_json(export_path, export)
    _write_json(predictions_path, predictions)

    summary = {
        "ok": True,
        # Report the same expanded path that was written (previously ``~`` was
        # not expanded here, so the printed path could be wrong).
        "export": str(export_path),
        "n_rows": len(predictions["predictions"]),
        "predictions": str(predictions_path),
    }
    print(json.dumps(summary, ensure_ascii=False))
    return 0


def _write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` as indented UTF-8 JSON, creating parent directories first."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())