Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 22 additions & 31 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,11 @@ impl<C: Config> Client<C> {
&self,
request_factory: HttpRequestFactory,
) -> Result<Response, OpenAIError> {
self.executor.execute(request_factory).await
let response = self.executor.execute(request_factory).await?;
if !response.status().is_success() {
return Err(read_error_response(response).await);
}
Ok(response)
}

async fn execute_stream<O>(
Expand Down Expand Up @@ -724,37 +728,41 @@ impl<C: Config> Client<C> {
}

async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
let status = response.status();
let headers = response.headers().clone();
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
Ok((bytes, headers))
}

async fn read_error_response(response: Response) -> OpenAIError {
let status = response.status();
let bytes = match response.bytes().await {
Ok(b) => b,
Err(e) => return OpenAIError::Reqwest(e),
};

if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(OpenAIError::ApiError(ApiErrorResponse {
return OpenAIError::ApiError(ApiErrorResponse {
status_code: status,
api_error: ApiError {
message,
r#type: None,
param: None,
code: None,
},
}));
});
}

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

return Err(OpenAIError::ApiError(ApiErrorResponse {
// Deserialize response body from the error object
match serde_json::from_slice::<WrappedError>(bytes.as_ref()) {
Ok(wrapped) => OpenAIError::ApiError(ApiErrorResponse {
status_code: status,
api_error: wrapped_error.error,
}));
api_error: wrapped.error,
}),
Err(e) => map_deserialization_error(e, bytes.as_ref()),
}

Ok((bytes, headers))
}

/// Request which responds with SSE.
Expand All @@ -778,17 +786,6 @@ pub(crate) async fn stream_mapped_raw_events<O>(
where
O: DeserializeOwned + 'static,
{
if !response.status().is_success() {
return Box::pin(futures::stream::once(async move {
match read_response(response).await {
Ok(_) => Err(OpenAIError::InvalidArgument(
"stream request failed without an error body".into(),
)),
Err(error) => Err(error),
}
}));
}

let byte_stream = response
.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
Expand Down Expand Up @@ -837,12 +834,6 @@ where
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
if !response.status().is_success() {
if let Err(e) = read_response(response).await {
let _ = tx.send(Err(e));
}
return;
}
let byte_stream = response
.bytes_stream()
.map(|r| r.map_err(std::io::Error::other));
Expand Down
Loading