Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 10 additions & 93 deletions async-openai/src/audio.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
use bytes::Bytes;

use crate::{
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
CreateTranslationRequest, CreateTranslationResponseJson,
CreateTranslationResponseVerboseJson,
},
Client,
};
use crate::{config::Config, Client, Speech, Transcriptions, Translations};

/// Turn audio into text or text into audio.
/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
Expand All @@ -23,89 +11,18 @@ impl<'c, C: Config> Audio<'c, C> {
Self { client }
}

/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}

/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe_verbose_json(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}

/// Transcribes audio into the input language.
pub async fn transcribe_raw(
&self,
request: CreateTranscriptionRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/transcriptions", request)
.await
/// APIs in Speech group.
pub fn speech(&self) -> Speech<'_, C> {
Speech::new(self.client)
}

/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
/// APIs in Transcription group.
pub fn transcription(&self) -> Transcriptions<'_, C> {
Transcriptions::new(self.client)
}

/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate_verbose_json(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}

/// Transcribes audio into the input language.
pub async fn translate_raw(
&self,
request: CreateTranslationRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/translations", request)
.await
}

/// Generates audio from the input text.
pub async fn speech(
&self,
request: CreateSpeechRequest,
) -> Result<CreateSpeechResponse, OpenAIError> {
let bytes = self.client.post_raw("/audio/speech", request).await?;

Ok(CreateSpeechResponse { bytes })
/// APIs in Translation group.
pub fn translation(&self) -> Translations<'_, C> {
Translations::new(self.client)
}
}
75 changes: 74 additions & 1 deletion async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,79 @@ impl<C: Config> Client<C> {
self.execute(request_maker).await
}

pub(crate) async fn post_form_stream<O, F>(
&self,
path: &str,
form: F,
) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
where
F: Clone,
Form: AsyncTryFrom<F, Error = OpenAIError>,
O: DeserializeOwned + std::marker::Send + 'static,
{
// Build and execute request manually since multipart::Form is not Clone
// and .eventsource() requires cloneability
let response = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.headers(self.config.headers())
.send()
.await
.map_err(OpenAIError::Reqwest)?;

// Check for error status
if !response.status().is_success() {
return Err(read_response(response).await.unwrap_err());
}

// Convert response body to EventSource stream
let stream = response
.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let event_stream = eventsource_stream::EventStream::new(stream);

// Convert EventSource stream to our expected format
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
use futures::StreamExt;
let mut event_stream = std::pin::pin!(event_stream);

while let Some(event_result) = event_stream.next().await {
match event_result {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
StreamError::EventStream(e.to_string()),
)))) {
break;
}
}
Ok(event) => {
// eventsource_stream::Event is a struct with data field
if event.data == "[DONE]" {
break;
}

let response = match serde_json::from_str::<O>(&event.data) {
Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
Ok(output) => Ok(output),
};

if let Err(_e) = tx.send(response) {
break;
}
}
}
}
});

Ok(Box::pin(
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
))
}

/// Execute a HTTP request and retry on rate limit
///
/// request_maker serves one purpose: to be able to create request again
Expand Down Expand Up @@ -524,7 +597,7 @@ async fn map_stream_error(value: EventSourceError) -> OpenAIError {
"Unreachable because read_response returns err when status_code {status_code} is invalid"
))
}
_ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
_ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
}
}

Expand Down
5 changes: 4 additions & 1 deletion async-openai/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum OpenAIError {
FileReadError(String),
/// Error on SSE streaming
#[error("stream failed: {0}")]
StreamError(StreamError),
StreamError(Box<StreamError>),
/// Error from client side validation
/// or when builder fails to build request before making API call
#[error("invalid args: {0}")]
Expand All @@ -36,6 +36,9 @@ pub enum StreamError {
/// Error when a stream event does not match one of the expected values
#[error("Unknown event: {0:#?}")]
UnknownEvent(eventsource_stream::Event),
/// Error from eventsource_stream when parsing SSE
#[error("EventStream error: {0}")]
EventStream(String),
}

/// OpenAI API returns error object on failure
Expand Down
6 changes: 6 additions & 0 deletions async-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,12 @@ mod project_users;
mod projects;
mod responses;
mod runs;
mod speech;
mod steps;
mod threads;
pub mod traits;
mod transcriptions;
mod translations;
pub mod types;
mod uploads;
mod users;
Expand Down Expand Up @@ -207,8 +210,11 @@ pub use project_users::ProjectUsers;
pub use projects::Projects;
pub use responses::Responses;
pub use runs::Runs;
pub use speech::Speech;
pub use steps::Steps;
pub use threads::Threads;
pub use transcriptions::Transcriptions;
pub use translations::Translations;
pub use uploads::Uploads;
pub use users::Users;
pub use vector_store_file_batches::VectorStoreFileBatches;
Expand Down
54 changes: 54 additions & 0 deletions async-openai/src/speech.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use crate::{
config::Config,
error::OpenAIError,
types::audio::{CreateSpeechRequest, CreateSpeechResponse, SpeechResponseStream},
Client,
};

pub struct Speech<'c, C: Config> {
client: &'c Client<C>,
}

impl<'c, C: Config> Speech<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}

/// Generates audio from the input text.
pub async fn create(
&self,
request: CreateSpeechRequest,
) -> Result<CreateSpeechResponse, OpenAIError> {
let bytes = self.client.post_raw("/audio/speech", request).await?;

Ok(CreateSpeechResponse { bytes })
}

/// Generates audio from the input text in SSE stream format.
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned,
stream = "true",
where_clause = "R: std::marker::Send + 'static"
)]
#[allow(unused_mut)]
pub async fn create_stream(
&self,
mut request: CreateSpeechRequest,
) -> Result<SpeechResponseStream, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
use crate::types::audio::StreamFormat;
if let Some(stream_format) = request.stream_format {
if stream_format != StreamFormat::SSE {
return Err(OpenAIError::InvalidArgument(
"When stream_format is not SSE, use Audio::speech".into(),
));
}
}

request.stream_format = Some(StreamFormat::SSE);
}
Ok(self.client.post_stream("/audio/speech", request).await)
}
}
Loading
Loading