Skip to content

Commit c9c2345

Browse files
64bitifsheldon
authored and committed
feat: sync audio api + fix error types (64bit#464)
* updates to CreateSpeechRequest * types::audio; streaming audio types and api * fix imports in the examples * fix import * fix api; implement trait * add audio-speech-stream example * updated CreateTranscriptionRequest * updates for CreateTranscriptionResponseJson * updated CreateTranscriptionResponseVerboseJson * updated CreateTranscriptionResponseDiarizedJson * update types for diarized * update transcription example * streaming from form submission * add streaming example to audio-transcribe * updates for translation * update to example * update audio api groups * update audio examples * fix webhooks error type * fix errors reported by clippy * fix based on clippy * fix for: https://rust-lang.github.io/rust-clippy/master/index.html#result_large_err * cargo fmt (cherry picked from commit eaa396b)
1 parent ed03e50 commit c9c2345

21 files changed

Lines changed: 1127 additions & 438 deletions

File tree

async-openai/src/audio.rs

Lines changed: 10 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,4 @@
1-
use bytes::Bytes;
2-
3-
use crate::{
4-
Client,
5-
config::Config,
6-
error::OpenAIError,
7-
types::{
8-
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
9-
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
10-
CreateTranslationRequest, CreateTranslationResponseJson,
11-
CreateTranslationResponseVerboseJson,
12-
},
13-
};
1+
use crate::{config::Config, Client, Speech, Transcriptions, Translations};
142

153
/// Turn audio into text or text into audio.
164
/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
@@ -23,89 +11,18 @@ impl<'c, C: Config> Audio<'c, C> {
2311
Self { client }
2412
}
2513

26-
/// Transcribes audio into the input language.
27-
#[crate::byot(
28-
T0 = Clone,
29-
R = serde::de::DeserializeOwned,
30-
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
31-
)]
32-
pub async fn transcribe(
33-
&self,
34-
request: CreateTranscriptionRequest,
35-
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
36-
self.client
37-
.post_form("/audio/transcriptions", request)
38-
.await
39-
}
40-
41-
/// Transcribes audio into the input language.
42-
#[crate::byot(
43-
T0 = Clone,
44-
R = serde::de::DeserializeOwned,
45-
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
46-
)]
47-
pub async fn transcribe_verbose_json(
48-
&self,
49-
request: CreateTranscriptionRequest,
50-
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
51-
self.client
52-
.post_form("/audio/transcriptions", request)
53-
.await
54-
}
55-
56-
/// Transcribes audio into the input language.
57-
pub async fn transcribe_raw(
58-
&self,
59-
request: CreateTranscriptionRequest,
60-
) -> Result<Bytes, OpenAIError> {
61-
self.client
62-
.post_form_raw("/audio/transcriptions", request)
63-
.await
14+
/// APIs in Speech group.
15+
pub fn speech(&self) -> Speech<'_, C> {
16+
Speech::new(self.client)
6417
}
6518

66-
/// Translates audio into English.
67-
#[crate::byot(
68-
T0 = Clone,
69-
R = serde::de::DeserializeOwned,
70-
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
71-
)]
72-
pub async fn translate(
73-
&self,
74-
request: CreateTranslationRequest,
75-
) -> Result<CreateTranslationResponseJson, OpenAIError> {
76-
self.client.post_form("/audio/translations", request).await
19+
/// APIs in Transcription group.
20+
pub fn transcription(&self) -> Transcriptions<'_, C> {
21+
Transcriptions::new(self.client)
7722
}
7823

79-
/// Translates audio into English.
80-
#[crate::byot(
81-
T0 = Clone,
82-
R = serde::de::DeserializeOwned,
83-
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
84-
)]
85-
pub async fn translate_verbose_json(
86-
&self,
87-
request: CreateTranslationRequest,
88-
) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
89-
self.client.post_form("/audio/translations", request).await
90-
}
91-
92-
/// Transcribes audio into the input language.
93-
pub async fn translate_raw(
94-
&self,
95-
request: CreateTranslationRequest,
96-
) -> Result<Bytes, OpenAIError> {
97-
self.client
98-
.post_form_raw("/audio/translations", request)
99-
.await
100-
}
101-
102-
/// Generates audio from the input text.
103-
pub async fn speech(
104-
&self,
105-
request: CreateSpeechRequest,
106-
) -> Result<CreateSpeechResponse, OpenAIError> {
107-
let bytes = self.client.post_raw("/audio/speech", request).await?;
108-
109-
Ok(CreateSpeechResponse { bytes })
24+
/// APIs in Translation group.
25+
pub fn translation(&self) -> Translations<'_, C> {
26+
Translations::new(self.client)
11027
}
11128
}

async-openai/src/client.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,79 @@ impl<C: Config> Client<C> {
339339
.await
340340
}
341341

342+
/// POSTs `form` as a multipart request to `path` and exposes the response
/// body as a stream of deserialized server-sent events.
///
/// The request is sent to `self.config.url(path)` with the query parameters
/// and headers taken from `self.config`. Each SSE event's `data` payload is
/// deserialized into `O` with `serde_json`; an event whose data is exactly
/// `"[DONE]"` ends the stream.
///
/// # Errors
///
/// Returns `Err` before any stream is produced when converting `form` into a
/// `Form` fails, when sending the request fails (`OpenAIError::Reqwest`), or
/// when the response status is not a success. Once streaming has started,
/// SSE parse failures and JSON deserialization failures are delivered as
/// `Err` items of the returned stream.
pub(crate) async fn post_form_stream<O, F>(
343+
&self,
344+
path: &str,
345+
form: F,
346+
) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
347+
where
348+
F: Clone,
349+
Form: AsyncTryFrom<F, Error = OpenAIError>,
350+
O: DeserializeOwned + std::marker::Send + 'static,
351+
{
352+
// Build and execute request manually since multipart::Form is not Clone
353+
// and .eventsource() requires cloneability
354+
let response = self
355+
.http_client
356+
.post(self.config.url(path))
357+
.query(&self.config.query())
358+
// NOTE(review): `form` is owned and not used again in this function, so
// this `.clone()` looks unnecessary — confirm before removing it.
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
359+
.headers(self.config.headers())
360+
.send()
361+
.await
362+
.map_err(OpenAIError::Reqwest)?;
363+
364+
// Check for error status
365+
if !response.status().is_success() {
366+
// NOTE(review): `unwrap_err()` panics if `read_response` returns `Ok`;
// this assumes it always yields `Err` for a non-success status — TODO confirm.
return Err(read_response(response).await.unwrap_err());
367+
}
368+
369+
// Convert response body to EventSource stream
370+
let stream = response
371+
.bytes_stream()
372+
// Transport errors are re-wrapped as `std::io::Error` before being
// handed to the SSE parser.
.map(|result| result.map_err(std::io::Error::other));
373+
let event_stream = eventsource_stream::EventStream::new(stream);
374+
375+
// Convert EventSource stream to our expected format
376+
// The channel is unbounded, so the forwarding task below is never slowed
// down by a slow consumer; parsed events queue up in memory instead.
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
377+
378+
// The join handle is dropped, so the task runs detached; it ends when the
// event stream is exhausted, on "[DONE]", or when a send fails.
tokio::spawn(async move {
379+
use futures::StreamExt;
380+
let mut event_stream = std::pin::pin!(event_stream);
381+
382+
while let Some(event_result) = event_stream.next().await {
383+
match event_result {
384+
Err(e) => {
385+
// Forward the SSE error as a stream item; the loop only stops here
// if the send itself fails.
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
386+
StreamError::EventStream(e.to_string()),
387+
)))) {
388+
break;
389+
}
390+
}
391+
Ok(event) => {
392+
// eventsource_stream::Event is a struct with data field
393+
// "[DONE]" is treated as the end-of-stream sentinel and is not forwarded.
if event.data == "[DONE]" {
394+
break;
395+
}
396+
397+
let response = match serde_json::from_str::<O>(&event.data) {
398+
Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
399+
Ok(output) => Ok(output),
400+
};
401+
402+
// A failed send ends the task.
if let Err(_e) = tx.send(response) {
403+
break;
404+
}
405+
}
406+
}
407+
}
408+
});
409+
410+
// Hand the receiving half back to the caller as a pinned, boxed stream.
Ok(Box::pin(
411+
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
412+
))
413+
}
414+
342415
/// Execute a HTTP request
343416
async fn execute_raw(
344417
&self,

async-openai/src/error.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub enum OpenAIError {
2020
FileReadError(String),
2121
/// Error on SSE streaming
2222
#[error("stream failed: {0}")]
23-
StreamError(StreamError),
23+
StreamError(Box<StreamError>),
2424
/// Error from client side validation
2525
/// or when builder fails to build request before making API call
2626
#[error("invalid args: {0}")]
@@ -35,6 +35,9 @@ pub enum StreamError {
3535
/// Error when a stream event does not match one of the expected values
3636
#[error("Unknown event: {0:#?}")]
3737
UnknownEvent(eventsource_stream::Event),
38+
/// Error from eventsource_stream when parsing SSE
39+
#[error("EventStream error: {0}")]
40+
EventStream(String),
3841
}
3942

4043
/// OpenAI API returns error object on failure

async-openai/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,12 @@ mod project_users;
168168
mod projects;
169169
mod responses;
170170
mod runs;
171+
mod speech;
171172
mod steps;
172173
mod threads;
173174
pub mod traits;
175+
mod transcriptions;
176+
mod translations;
174177
pub mod types;
175178
mod uploads;
176179
mod users;
@@ -207,8 +210,11 @@ pub use project_users::ProjectUsers;
207210
pub use projects::Projects;
208211
pub use responses::Responses;
209212
pub use runs::Runs;
213+
pub use speech::Speech;
210214
pub use steps::Steps;
211215
pub use threads::Threads;
216+
pub use transcriptions::Transcriptions;
217+
pub use translations::Translations;
212218
pub use uploads::Uploads;
213219
pub use users::Users;
214220
pub use vector_store_file_batches::VectorStoreFileBatches;

async-openai/src/speech.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use crate::{
2+
config::Config,
3+
error::OpenAIError,
4+
types::audio::{CreateSpeechRequest, CreateSpeechResponse, SpeechResponseStream},
5+
Client,
6+
};
7+
8+
/// Speech API group: turns input text into audio via the `/audio/speech`
/// endpoint. Holds a borrowed [`Client`] that its methods send requests through.
pub struct Speech<'c, C: Config> {
9+
client: &'c Client<C>,
10+
}
11+
12+
impl<'c, C: Config> Speech<'c, C> {
13+
/// Creates a `Speech` handle that sends its requests through `client`.
pub fn new(client: &'c Client<C>) -> Self {
14+
Self { client }
15+
}
16+
17+
/// Generates audio from the input text.
///
/// Posts `request` to `/audio/speech` and returns the raw response bytes
/// wrapped in a [`CreateSpeechResponse`].
18+
pub async fn create(
19+
&self,
20+
request: CreateSpeechRequest,
21+
) -> Result<CreateSpeechResponse, OpenAIError> {
22+
let bytes = self.client.post_raw("/audio/speech", request).await?;
23+
24+
Ok(CreateSpeechResponse { bytes })
25+
}
26+
27+
/// Generates audio from the input text in SSE stream format.
///
/// When the `byot` feature is disabled, `request.stream_format` is forced to
/// `StreamFormat::SSE`; a request that explicitly sets any other format is
/// rejected with [`OpenAIError::InvalidArgument`] before anything is sent.
28+
#[crate::byot(
29+
T0 = serde::Serialize,
30+
R = serde::de::DeserializeOwned,
31+
stream = "true",
32+
where_clause = "R: std::marker::Send + 'static"
33+
)]
34+
// `request` is only mutated inside the `cfg(not(feature = "byot"))` block
// below, so `mut` is unused when that block is compiled out.
#[allow(unused_mut)]
35+
pub async fn create_stream(
36+
&self,
37+
mut request: CreateSpeechRequest,
38+
) -> Result<SpeechResponseStream, OpenAIError> {
39+
#[cfg(not(feature = "byot"))]
40+
{
41+
use crate::types::audio::StreamFormat;
42+
if let Some(stream_format) = request.stream_format {
43+
if stream_format != StreamFormat::SSE {
44+
// NOTE(review): the message says `Audio::speech`, but the non-streaming
// method in this impl is `create` — confirm the wording is still right.
return Err(OpenAIError::InvalidArgument(
45+
"When stream_format is not SSE, use Audio::speech".into(),
46+
));
47+
}
48+
}
49+
50+
// An unset format is defaulted to SSE.
request.stream_format = Some(StreamFormat::SSE);
51+
}
52+
// The result of `post_stream` is wrapped in `Ok` as-is, so errors from the
// request itself are not returned from this line.
Ok(self.client.post_stream("/audio/speech", request).await)
53+
}
54+
}

0 commit comments

Comments
 (0)