Skip to content

Commit 7b0f399

Browse files
committed
fix(moonshine): return 400 on audio load failure
1 parent aab5528 commit 7b0f399

4 files changed

Lines changed: 266 additions & 82 deletions

File tree

src/cpp/server/backends/moonshine_server.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,35 @@ json MoonshineServer::forward_multipart_audio_data(const std::string& audio_data
295295
LOG(DEBUG, "MoonshineServer") << "Response status: " << res.status_code << std::endl;
296296

297297
if (res.status_code != 200) {
298-
throw std::runtime_error("moonshine-server returned status " +
299-
std::to_string(res.status_code) + ": " + res.body);
298+
std::string err_msg = res.body;
299+
std::string err_type = "audio_processing_error";
300+
int status_code = res.status_code;
301+
302+
try {
303+
json error_json = json::parse(res.body);
304+
if (error_json.contains("error")) {
305+
if (error_json["error"].is_string()) {
306+
err_msg = error_json["error"].get<std::string>();
307+
} else if (error_json["error"].is_object() && error_json["error"].contains("message")) {
308+
err_msg = error_json["error"]["message"].get<std::string>();
309+
}
310+
}
311+
} catch (...) {
312+
// Keep res.body as raw error message
313+
}
314+
315+
if (status_code == 400 || (status_code == 500 && err_msg.find("Not a valid RIFF file") != std::string::npos)) {
316+
status_code = 400;
317+
err_type = "invalid_request_error";
318+
}
319+
320+
return json{
321+
{"error", {
322+
{"message", "Transcription failed: " + err_msg},
323+
{"type", err_type},
324+
{"status_code", status_code}
325+
}}
326+
};
300327
}
301328

302329
try {
@@ -322,7 +349,8 @@ json MoonshineServer::audio_transcriptions(const json& request) {
322349
return json{
323350
{"error", {
324351
{"message", std::string("Transcription failed: ") + e.what()},
325-
{"type", "audio_processing_error"}
352+
{"type", "audio_processing_error"},
353+
{"status_code", 500}
326354
}}
327355
};
328356
}

src/cpp/server/server.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2578,7 +2578,8 @@ void Server::handle_audio_transcriptions(const httplib::Request& req, httplib::R
25782578

25792579
// Check for error in response
25802580
if (response.contains("error")) {
2581-
res.status = 500;
2581+
set_error_response(response, res, 500);
2582+
return;
25822583
}
25832584

25842585
res.set_content(response.dump(), "application/json");

test/server_moonshine.py

Lines changed: 129 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,17 @@ def _generate_test_audio(self, duration_sec=3.0, sample_rate=16000):
4949
import math
5050

5151
n_samples = int(duration_sec * sample_rate)
52-
tmp_path = os.path.join(tempfile.gettempdir(), "moonshine_test_audio.wav")
52+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
53+
tmp_path = tmp.name
5354

5455
with wave.open(tmp_path, "w") as wf:
5556
wf.setnchannels(1)
5657
wf.setsampwidth(2)
5758
wf.setframerate(sample_rate)
5859
for i in range(n_samples):
59-
sample = int(math.sin(2 * math.pi * 440 * i / sample_rate) * 0.3 * 32767)
60+
sample = int(
61+
math.sin(2 * math.pi * 440 * i / sample_rate) * 0.3 * 32767
62+
)
6063
wf.writeframes(struct.pack("<h", sample))
6164

6265
return tmp_path
@@ -85,7 +88,9 @@ def test_moonshine_file_transcription(self):
8588
data = {"model": model_name, "response_format": "json"}
8689
resp = requests.post(url, files=files, data=data, timeout=60)
8790

88-
self.assertEqual(resp.status_code, 200, f"Transcription failed: {resp.text}")
91+
self.assertEqual(
92+
resp.status_code, 200, f"Transcription failed: {resp.text}"
93+
)
8994

9095
result = resp.json()
9196
self.assertIn("text", result)
@@ -145,10 +150,14 @@ async def stream() -> tuple[set, str]:
145150
frames_per_chunk = rate * chunk_ms // 1000
146151

147152
async with websockets.connect(ws_url) as ws:
148-
await ws.send(jsonlib.dumps({
149-
"type": "session.update",
150-
"session": {"model": model_name},
151-
}))
153+
await ws.send(
154+
jsonlib.dumps(
155+
{
156+
"type": "session.update",
157+
"session": {"model": model_name},
158+
}
159+
)
160+
)
152161

153162
async def reader():
154163
nonlocal final_text
@@ -171,39 +180,49 @@ async def reader():
171180
# continue for the audio that follows
172181
for _ in range(3):
173182
frames = wf.readframes(frames_per_chunk)
174-
await ws.send(jsonlib.dumps({
175-
"type": "input_audio_buffer.append",
176-
"audio": base64.b64encode(frames).decode(),
177-
}))
183+
await ws.send(
184+
jsonlib.dumps(
185+
{
186+
"type": "input_audio_buffer.append",
187+
"audio": base64.b64encode(frames).decode(),
188+
}
189+
)
190+
)
178191
await asyncio.sleep(chunk_ms / 1000)
179192
wf.rewind()
180-
await ws.send(jsonlib.dumps(
181-
{"type": "input_audio_buffer.clear"}
182-
))
193+
await ws.send(
194+
jsonlib.dumps({"type": "input_audio_buffer.clear"})
195+
)
183196
await asyncio.sleep(0.5)
184197

185198
while True:
186199
frames = wf.readframes(frames_per_chunk)
187200
if not frames:
188201
break
189-
await ws.send(jsonlib.dumps({
190-
"type": "input_audio_buffer.append",
191-
"audio": base64.b64encode(frames).decode(),
192-
}))
202+
await ws.send(
203+
jsonlib.dumps(
204+
{
205+
"type": "input_audio_buffer.append",
206+
"audio": base64.b64encode(frames).decode(),
207+
}
208+
)
209+
)
193210
await asyncio.sleep(chunk_ms / 1000)
194211

195212
# Trailing silence lets the streaming model close the line
196213
silence = b"\x00\x00" * frames_per_chunk
197214
for _ in range(15):
198-
await ws.send(jsonlib.dumps({
199-
"type": "input_audio_buffer.append",
200-
"audio": base64.b64encode(silence).decode(),
201-
}))
215+
await ws.send(
216+
jsonlib.dumps(
217+
{
218+
"type": "input_audio_buffer.append",
219+
"audio": base64.b64encode(silence).decode(),
220+
}
221+
)
222+
)
202223
await asyncio.sleep(chunk_ms / 1000)
203224

204-
await ws.send(jsonlib.dumps(
205-
{"type": "input_audio_buffer.commit"}
206-
))
225+
await ws.send(jsonlib.dumps({"type": "input_audio_buffer.commit"}))
207226
await asyncio.sleep(3)
208227
rtask.cancel()
209228
return events, final_text
@@ -281,10 +300,10 @@ def test_moonshine_thread_count(self):
281300

282301
# Find the moonshine-server subprocess PID
283302
import subprocess
303+
284304
try:
285305
output = subprocess.check_output(
286-
["pgrep", "-f", "moonshine-server"],
287-
text=True
306+
["pgrep", "-f", "moonshine-server"], text=True
288307
).strip()
289308
if not output:
290309
self.skipTest("moonshine-server process not found")
@@ -295,14 +314,95 @@ def test_moonshine_thread_count(self):
295314
if line.startswith("Threads:"):
296315
thread_count = int(line.split()[1])
297316
self.assertLessEqual(
298-
thread_count, 10,
299-
f"moonshine-server spawned {thread_count} threads (expected <= 10)"
317+
thread_count,
318+
10,
319+
f"moonshine-server spawned {thread_count} threads (expected <= 10)",
300320
)
301321
print(f"[MoonshineTest] Thread count: {thread_count}")
302322
break
303323
except Exception as e:
304324
self.skipTest(f"Could not check thread count: {e}")
305325

326+
@skip_if_unsupported("transcription")
327+
def test_moonshine_invalid_file_transcription(self):
328+
"""Test that transcribing an invalid/unsupported file returns a 400 Bad Request."""
329+
model_name = _get_moonshine_model()
330+
if not model_name or "Moonshine" not in model_name:
331+
self.skipTest("No Moonshine model configured for testing")
332+
333+
self._load_model(model_name)
334+
335+
# Create an invalid WAV file (just some text content) using NamedTemporaryFile
336+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
337+
tmp.write(b"This is not a valid RIFF/WAV file.")
338+
tmp_path = tmp.name
339+
340+
try:
341+
url = f"http://127.0.0.1:{PORT}/v1/audio/transcriptions"
342+
with open(tmp_path, "rb") as f:
343+
files = {"file": ("test.wav", f, "audio/wav")}
344+
data = {"model": model_name, "response_format": "json"}
345+
resp = requests.post(url, files=files, data=data, timeout=60)
346+
347+
# We expect a 400 Bad Request
348+
self.assertEqual(
349+
resp.status_code,
350+
400,
351+
f"Expected 400 but got {resp.status_code}: {resp.text}",
352+
)
353+
354+
result = resp.json()
355+
self.assertIn("error", result)
356+
self.assertIn("message", result["error"])
357+
self.assertEqual(result["error"]["type"], "invalid_request_error")
358+
359+
finally:
360+
if os.path.exists(tmp_path):
361+
os.remove(tmp_path)
362+
363+
@skip_if_unsupported("transcription")
364+
def test_moonshine_unsupported_ogg_transcription(self):
365+
"""Test that transcribing a valid but unsupported Ogg audio file returns a 400 Bad Request."""
366+
model_name = _get_moonshine_model()
367+
if not model_name or "Moonshine" not in model_name:
368+
self.skipTest("No Moonshine model configured for testing")
369+
370+
self._load_model(model_name)
371+
372+
import base64
373+
374+
# A tiny valid silent Ogg file fixture (180 bytes)
375+
ogg_data = base64.b64decode(
376+
"T2dnUwACAAAAAAAAAAAyzN3NAAAAAGFf2X8BM39GTEFDAQAAAWZMYUMAAAAiEgASAAAAAAAkFQrEQPAAAAAAAAAAAAAAAAAAAAAAAAAAAE9nZ1MAAAAAAAAAAAAAMszdzQEAAAD5LKCSATeEAAAzDQAAAExhdmY1NS40OC4xMDABAAAAGgAAAGVuY29kZXI9TGF2YzU1LjY5LjEwMCBmbGFjT2dnUwAEARIAAAAAAAAyzN3NAgAAAKWVljkCDAD/+GkIAAAdAAABICI="
377+
)
378+
379+
with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as tmp:
380+
tmp.write(ogg_data)
381+
tmp_path = tmp.name
382+
383+
try:
384+
url = f"http://127.0.0.1:{PORT}/v1/audio/transcriptions"
385+
with open(tmp_path, "rb") as f:
386+
files = {"file": ("test.ogg", f, "audio/ogg")}
387+
data = {"model": model_name, "response_format": "json"}
388+
resp = requests.post(url, files=files, data=data, timeout=60)
389+
390+
# Ogg is unsupported by Moonshine, so it must return 400 Bad Request
391+
self.assertEqual(
392+
resp.status_code,
393+
400,
394+
f"Expected 400 but got {resp.status_code}: {resp.text}",
395+
)
396+
397+
result = resp.json()
398+
self.assertIn("error", result)
399+
self.assertIn("message", result["error"])
400+
self.assertEqual(result["error"]["type"], "invalid_request_error")
401+
402+
finally:
403+
if os.path.exists(tmp_path):
404+
os.remove(tmp_path)
405+
306406

307407
if __name__ == "__main__":
308408
run_server_tests(MoonshineTests)

0 commit comments

Comments
 (0)