Skip to content

Commit 4be2be3

Browse files
lalaluneclaude
andcommitted
wip(O-turn-intl): add multilingual smoke-test harness for the exported ONNX
Hand-crafted complete-utterance / mid-utterance-prefix pairs in English, Spanish, Japanese, German, Mandarin, and French. The brief required en + 2 non-en samples; we cover 6 LiveKit-supported locales so the report shows whether intl coverage holds (vs. catastrophic forgetting on en or one of the CJK scripts). Scoring matches probabilityFromOnnxOutput in eot-classifier.ts: P(EOU) = softmax(logits[:, last_real_pos, :])[<|im_end|>]; exit 0 when every complete utterance scores ≥ threshold and every prefix scores < threshold. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6d581ae commit 4be2be3

1 file changed

Lines changed: 254 additions & 0 deletions

File tree

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
#!/usr/bin/env python3
2+
"""Smoke-test the multilingual ONNX turn detector on hand-crafted samples.
3+
4+
Per the O-turn-intl brief: smoke-test the exported INT8 ONNX on English
5+
plus two non-English samples (Spanish, Japanese), covering complete
6+
utterances and mid-utterance prefixes for each. Runs against the same
7+
scoring path as ``probabilityFromOnnxOutput`` in the runtime:
8+
9+
P(EOU) = softmax(logits[:, last_real_pos, :])[<|im_end|>]
10+
11+
Exit code:
12+
0 — every complete utterance scored ≥ ``--decision-threshold`` and
13+
every prefix scored < ``--decision-threshold``.
14+
1 — at least one classification disagreed.
15+
16+
The summary JSON (saved to ``--report``) records the raw probability
17+
per row so we can chart the margin between complete/incomplete.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import argparse
23+
import json
24+
import sys
25+
from dataclasses import dataclass
26+
from pathlib import Path
27+
from typing import Any, Final
28+
29+
30+
LIVEKIT_IM_END_TOKEN: Final[str] = "<|im_end|>"
31+
32+
# Hand-crafted bilingual EOU / non-EOU pairs. Each row covers a single
33+
# "complete utterance vs. random mid-utterance prefix" comparison.
34+
# Spanish + Japanese are required by the brief; we add German + Mandarin
35+
# + French as bonus coverage because the LiveKit base model already
36+
# claims support there.
37+
SMOKE_CASES: Final[tuple[dict[str, Any], ...]] = (
38+
{
39+
"lang": "en",
40+
"complete": "Can you please tell me what time the meeting starts.",
41+
"prefix": "Can you please tell me what",
42+
},
43+
{
44+
"lang": "en",
45+
"complete": "I'm done speaking, your turn.",
46+
"prefix": "I'm done",
47+
},
48+
{
49+
"lang": "es",
50+
"complete": "¿Me puedes decir a qué hora empieza la reunión?",
51+
"prefix": "¿Me puedes decir a qué",
52+
},
53+
{
54+
"lang": "es",
55+
"complete": "He terminado de hablar, te toca.",
56+
"prefix": "He terminado",
57+
},
58+
{
59+
"lang": "ja",
60+
"complete": "会議は何時に始まりますか?",
61+
"prefix": "会議は何時",
62+
},
63+
{
64+
"lang": "ja",
65+
"complete": "もう話し終わりました、どうぞ。",
66+
"prefix": "もう話",
67+
},
68+
{
69+
"lang": "de",
70+
"complete": "Können Sie mir bitte sagen, wann das Meeting beginnt?",
71+
"prefix": "Können Sie mir bitte",
72+
},
73+
{
74+
"lang": "zh",
75+
"complete": "请问会议什么时候开始?",
76+
"prefix": "请问会议",
77+
},
78+
{
79+
"lang": "fr",
80+
"complete": "Pouvez-vous me dire à quelle heure commence la réunion ?",
81+
"prefix": "Pouvez-vous me dire à",
82+
},
83+
)
84+
85+
86+
@dataclass
87+
class SmokeRow:
88+
lang: str
89+
text: str
90+
expected: int # 1 = complete (EOU), 0 = prefix
91+
probability: float
92+
93+
def predicted(self, threshold: float) -> int:
94+
return 1 if self.probability >= threshold else 0
95+
96+
97+
def _format_livekit_prompt(tokenizer: Any, transcript: str) -> str:
98+
templated = tokenizer.apply_chat_template(
99+
[{"role": "user", "content": transcript}],
100+
add_generation_prompt=False,
101+
tokenize=False,
102+
add_special_tokens=False,
103+
)
104+
ix = templated.rfind(LIVEKIT_IM_END_TOKEN)
105+
if ix >= 0:
106+
templated = templated[:ix]
107+
return templated
108+
109+
110+
def _resolve_im_end_id(tokenizer: Any) -> int:
111+
ids = tokenizer(LIVEKIT_IM_END_TOKEN, add_special_tokens=False)["input_ids"]
112+
if not ids:
113+
raise SystemExit("tokenizer did not produce an <|im_end|> id")
114+
return int(ids[0])
115+
116+
117+
def smoke_test(
118+
*,
119+
model_path: Path,
120+
tokenizer_path: Path,
121+
decision_threshold: float = 0.5,
122+
cases: tuple[dict[str, Any], ...] = SMOKE_CASES,
123+
) -> dict[str, Any]:
124+
"""Run the smoke set against the fine-tuned ONNX.
125+
126+
Returns a dict with::
127+
128+
{
129+
"passed": bool,
130+
"decision_threshold": float,
131+
"rows": list[{"lang", "text", "expected", "probability", "predicted"}],
132+
"summary": {
133+
"<lang>": {"complete": [float], "prefix": [float], "passed": bool},
134+
...
135+
}
136+
}
137+
"""
138+
try:
139+
import numpy as np
140+
import onnxruntime
141+
from transformers import AutoTokenizer
142+
except ModuleNotFoundError as exc:
143+
raise SystemExit(
144+
"onnxruntime + transformers required for smoke test"
145+
) from exc
146+
147+
tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
148+
if tokenizer.pad_token_id is None:
149+
tokenizer.pad_token = tokenizer.eos_token
150+
im_end_id = _resolve_im_end_id(tokenizer)
151+
session = onnxruntime.InferenceSession(
152+
str(model_path), providers=["CPUExecutionProvider"],
153+
)
154+
155+
rows: list[SmokeRow] = []
156+
157+
def _score(transcript: str) -> float:
158+
prompt = _format_livekit_prompt(tokenizer, transcript)
159+
encoded = tokenizer(
160+
prompt,
161+
return_tensors="np",
162+
max_length=128,
163+
truncation=True,
164+
add_special_tokens=False,
165+
)
166+
outputs = session.run(
167+
None, {"input_ids": encoded["input_ids"].astype("int64")}
168+
)
169+
logits = outputs[0][0, -1, :].astype("float64")
170+
logits = logits - logits.max()
171+
probs = np.exp(logits) / np.exp(logits).sum()
172+
return float(probs[im_end_id])
173+
174+
for case in cases:
175+
lang = case["lang"]
176+
rows.append(
177+
SmokeRow(
178+
lang=lang,
179+
text=case["complete"],
180+
expected=1,
181+
probability=_score(case["complete"]),
182+
)
183+
)
184+
rows.append(
185+
SmokeRow(
186+
lang=lang,
187+
text=case["prefix"],
188+
expected=0,
189+
probability=_score(case["prefix"]),
190+
)
191+
)
192+
193+
summary: dict[str, dict[str, Any]] = {}
194+
for row in rows:
195+
bucket = summary.setdefault(
196+
row.lang, {"complete": [], "prefix": [], "passed": True},
197+
)
198+
if row.expected == 1:
199+
bucket["complete"].append(round(row.probability, 6))
200+
else:
201+
bucket["prefix"].append(round(row.probability, 6))
202+
if row.predicted(decision_threshold) != row.expected:
203+
bucket["passed"] = False
204+
205+
all_passed = all(b["passed"] for b in summary.values())
206+
207+
return {
208+
"passed": all_passed,
209+
"decision_threshold": decision_threshold,
210+
"rows": [
211+
{
212+
"lang": r.lang,
213+
"text": r.text,
214+
"expected": r.expected,
215+
"probability": round(r.probability, 6),
216+
"predicted": r.predicted(decision_threshold),
217+
}
218+
for r in rows
219+
],
220+
"summary": summary,
221+
}
222+
223+
224+
def main(argv: list[str] | None = None) -> int:
225+
ap = argparse.ArgumentParser(description=__doc__)
226+
ap.add_argument("--model", required=True, type=Path)
227+
ap.add_argument(
228+
"--tokenizer",
229+
required=True,
230+
type=Path,
231+
help="Directory containing tokenizer.json + sidecars.",
232+
)
233+
ap.add_argument("--report", type=Path, default=None)
234+
ap.add_argument("--decision-threshold", type=float, default=0.5)
235+
args = ap.parse_args(sys.argv[1:] if argv is None else argv)
236+
237+
report = smoke_test(
238+
model_path=args.model,
239+
tokenizer_path=args.tokenizer,
240+
decision_threshold=args.decision_threshold,
241+
)
242+
if args.report:
243+
args.report.parent.mkdir(parents=True, exist_ok=True)
244+
args.report.write_text(
245+
json.dumps(report, indent=2, sort_keys=False, ensure_ascii=False)
246+
+ "\n",
247+
encoding="utf-8",
248+
)
249+
print(json.dumps(report, indent=2, sort_keys=False, ensure_ascii=False))
250+
return 0 if report["passed"] else 1
251+
252+
253+
if __name__ == "__main__":
254+
raise SystemExit(main())

0 commit comments

Comments
 (0)