Skip to content

Commit 8ba7572

Browse files
stkimmerjrebelo
andauthored
add multipart to create_file (#2)
* Add string enumeration * Fix string formating * Update version number * add multipart to create_file * Add special handling for multipart form data in request body * Update cargo lock * Add multipart form * Add error display implementation * Add some multipart implementations * Fix clippy and path to license --------- Co-authored-by: Joao Rebelo <[email protected]>
1 parent d05165f commit 8ba7572

File tree

10 files changed

+605
-71
lines changed

10 files changed

+605
-71
lines changed

Cargo.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
File renamed without changes.

openai_client/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "conversa_openai_client"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
edition = "2024"
55
authors = ["Joao Rebelo <[email protected]>"]
66
license = "Apache-2.0"
@@ -11,7 +11,7 @@ categories = ["api-bindings", "web-programming::http-client"]
1111
description = "A native Rust client for the complete OpenAI REST API."
1212

1313
[dependencies]
14-
reqwest = "0.12.22"
14+
reqwest = { version = "0.12.22", features = ["multipart"] }
1515
serde = { version = "1.0.219", features = ["derive"] }
1616
serde_json = "1.0.140"
1717
tokio = { version = "1.46.0", features = ["rt", "macros"] }

openai_client/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ A native Rust client for the complete OpenAI REST API.
44

55
[![Crates.io](https://img.shields.io/crates/v/conversa_openai_client.svg)](https://crates.io/crates/conversa_openai_client)
66
[![Docs.rs](https://docs.rs/conversa_openai_client/badge.svg)](https://docs.rs/conversa_openai_client)
7-
[![License](https://img.shields.io/crates/l/conversa_openai_client.svg)](./LICENSE)
7+
[![License](https://img.shields.io/crates/l/conversa_openai_client.svg)](../LICENSE)
88

99
---
1010

openai_client/build.rs

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ fn parse_object_type(name: &str, schema: &Yaml, output_file: &mut File) {
276276
.is_some()
277277
{
278278
generate_inner_object_name(name, &field_name)
279+
} else if property_hash.get(&Yaml::String("format".to_string()))
280+
== Some(&Yaml::String("binary".to_string()))
281+
{
282+
if field_name == "file" {
283+
"crate::multipart::File".to_string()
284+
} else {
285+
"Vec<u8>".to_string()
286+
}
279287
} else {
280288
"String".to_string()
281289
}
@@ -482,7 +490,6 @@ fn parse_oneof_type(name: &str, schema: &Yaml, output_file: &mut File) {
482490
writeln!(output_file, "#[serde(untagged)]").unwrap();
483491
writeln!(output_file, "pub enum {} {{", name).unwrap();
484492

485-
let mut string_created = false;
486493
for (index, one_of_variant) in one_of_list.iter().enumerate() {
487494
let one_of_variant_hash = one_of_variant.as_hash().unwrap();
488495
if let Some(doc) = one_of_variant_hash
@@ -510,9 +517,25 @@ fn parse_oneof_type(name: &str, schema: &Yaml, output_file: &mut File) {
510517
// Some variants have two String types to account for enumerations but for
511518
// our type this is not necessary because all String representations are
512519
// equal
513-
if !string_created {
520+
if let Some(Yaml::Array(enum_list)) =
521+
one_of_variant_hash.get(&Yaml::String("enum".to_string()))
522+
{
523+
for string_variant in enum_list {
524+
writeln!(
525+
output_file,
526+
"\t#[serde(rename=\"{}\")]",
527+
string_variant.as_str().unwrap()
528+
)
529+
.unwrap();
530+
writeln!(
531+
output_file,
532+
"\t{},",
533+
str_to_camel_case(string_variant.as_str().unwrap())
534+
)
535+
.unwrap();
536+
}
537+
} else {
514538
writeln!(output_file, "\tString(String),").unwrap();
515-
string_created = true;
516539
}
517540
}
518541
"integer" => {
@@ -1139,7 +1162,7 @@ fn parse_endpoint_path(path_schema: &Yaml, client_output_file: &mut File) {
11391162
{
11401163
"String".to_string()
11411164
} else {
1142-
todo!("{:?}", response_schema_hash)
1165+
unimplemented!("{:?}", response_schema_hash)
11431166
}
11441167
} else {
11451168
str_to_camel_case(&format!("{operation_name}_response"))
@@ -1171,7 +1194,7 @@ fn parse_endpoint_path(path_schema: &Yaml, client_output_file: &mut File) {
11711194
{
11721195
str_to_camel_case(&format!("{operation_name}_response"))
11731196
} else {
1174-
todo!("{:?}", response_schema_hash)
1197+
unimplemented!("{:?}", response_schema_hash)
11751198
}
11761199
} else {
11771200
str_to_camel_case(&format!("{operation_name}_response"))
@@ -1234,18 +1257,34 @@ fn parse_endpoint_path(path_schema: &Yaml, client_output_file: &mut File) {
12341257
{
12351258
let request_body_is_required =
12361259
request_body_hash["required"].as_bool().unwrap_or(false);
1237-
if request_body_is_required {
1238-
writeln!(
1260+
1261+
let request_body_content = request_body_hash["content"].as_hash().unwrap();
1262+
debug_assert!(request_body_content.len() == 1);
1263+
// TODO: It requires different handling depending on the type of request body (application/json or multipart/form-data)
1264+
let request_body_content_type =
1265+
request_body_content.front().unwrap().0.as_str().unwrap();
1266+
if request_body_content_type == "application/json" {
1267+
if request_body_is_required {
1268+
writeln!(
1269+
client_output_file,
1270+
"\t\trequest = request.body(serde_json::to_string(&request_body)?);",
1271+
)
1272+
.unwrap();
1273+
} else {
1274+
writeln!(
12391275
client_output_file,
1240-
"\t\trequest = request.body(serde_json::to_string(&request_body)?);",
1276+
"\t\tif let Some(b) = request_body {{\n\t\t\trequest = request.body(serde_json::to_string(&b)?);\n\t\t}}",
12411277
)
12421278
.unwrap();
1243-
} else {
1279+
}
1280+
} else if request_body_content_type == "multipart/form-data" {
12441281
writeln!(
12451282
client_output_file,
1246-
"\t\tif let Some(b) = request_body {{\n\t\t\trequest = request.body(serde_json::to_string(&b)?);\n\t\t}}",
1283+
"\t\trequest = request.multipart(request_body.into_multipart_form());",
12471284
)
12481285
.unwrap();
1286+
} else {
1287+
unimplemented!("Request body type: {}", request_body_content_type);
12491288
}
12501289
}
12511290

openai_client/openapi.documented.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41804,7 +41804,7 @@ components:
4180441804
- type: integer
4180541805
- type: string
4180641806
enum:
41807-
- inf
41807+
- "inf"
4180841808
x-stainless-const: true
4180941809
description: |
4181041810
Maximum number of output tokens for a single assistant response,
@@ -41891,7 +41891,7 @@ components:
4189141891
- type: integer
4189241892
- type: string
4189341893
enum:
41894-
- inf
41894+
- "inf"
4189541895
x-stainless-const: true
4189641896
description: |
4189741897
Maximum number of output tokens for a single assistant response,
@@ -43903,7 +43903,7 @@ components:
4390343903
- type: integer
4390443904
- type: string
4390543905
enum:
43906-
- inf
43906+
- "inf"
4390743907
x-stainless-const: true
4390843908
description: |
4390943909
Maximum number of output tokens for a single assistant response,
@@ -44187,7 +44187,7 @@ components:
4418744187
- type: integer
4418844188
- type: string
4418944189
enum:
44190-
- inf
44190+
- "inf"
4419144191
x-stainless-const: true
4419244192
description: |
4419344193
Maximum number of output tokens for a single assistant response,
@@ -44407,7 +44407,7 @@ components:
4440744407
- type: integer
4440844408
- type: string
4440944409
enum:
44410-
- inf
44410+
- "inf"
4441144411
x-stainless-const: true
4441244412
description: |
4441344413
Maximum number of output tokens for a single assistant response,

openai_client/src/client.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ impl OpenAIClient {
280280
let address = format!("{}/audio/transcriptions", self.base_address);
281281
let mut request = self.client.post(&address);
282282
request = request.bearer_auth(&self.api_key);
283-
request = request.body(serde_json::to_string(&request_body)?);
283+
request = request.multipart(request_body.into_multipart_form());
284284
let result = request.send().await?;
285285
let status_code = result.status().as_u16();
286286
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -306,7 +306,7 @@ impl OpenAIClient {
306306
let address = format!("{}/audio/translations", self.base_address);
307307
let mut request = self.client.post(&address);
308308
request = request.bearer_auth(&self.api_key);
309-
request = request.body(serde_json::to_string(&request_body)?);
309+
request = request.multipart(request_body.into_multipart_form());
310310
let result = request.send().await?;
311311
let status_code = result.status().as_u16();
312312
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -716,7 +716,7 @@ You can send either a multipart/form-data request with the raw file content, or
716716
let address = format!("{}/containers/{container_id}/files", self.base_address);
717717
let mut request = self.client.post(&address);
718718
request = request.bearer_auth(&self.api_key);
719-
request = request.body(serde_json::to_string(&request_body)?);
719+
request = request.multipart(request_body.into_multipart_form());
720720
let result = request.send().await?;
721721
let status_code = result.status().as_u16();
722722
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -1187,7 +1187,7 @@ Please [contact us](https://help.openai.com/) if you need to increase these stor
11871187
let address = format!("{}/files", self.base_address);
11881188
let mut request = self.client.post(&address);
11891189
request = request.bearer_auth(&self.api_key);
1190-
request = request.body(serde_json::to_string(&request_body)?);
1190+
request = request.multipart(request_body.into_multipart_form());
11911191
let result = request.send().await?;
11921192
let status_code = result.status().as_u16();
11931193
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -1594,7 +1594,7 @@ Response includes details of the enqueued job including job status and the name
15941594
let address = format!("{}/images/edits", self.base_address);
15951595
let mut request = self.client.post(&address);
15961596
request = request.bearer_auth(&self.api_key);
1597-
request = request.body(serde_json::to_string(&request_body)?);
1597+
request = request.multipart(request_body.into_multipart_form());
15981598
let result = request.send().await?;
15991599
let status_code = result.status().as_u16();
16001600
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -1638,7 +1638,7 @@ Response includes details of the enqueued job including job status and the name
16381638
let address = format!("{}/images/variations", self.base_address);
16391639
let mut request = self.client.post(&address);
16401640
request = request.bearer_auth(&self.api_key);
1641-
request = request.body(serde_json::to_string(&request_body)?);
1641+
request = request.multipart(request_body.into_multipart_form());
16421642
let result = request.send().await?;
16431643
let status_code = result.status().as_u16();
16441644
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();
@@ -3912,7 +3912,7 @@ It is possible to add multiple Parts in parallel. You can decide the intended or
39123912
let address = format!("{}/uploads/{upload_id}/parts", self.base_address);
39133913
let mut request = self.client.post(&address);
39143914
request = request.bearer_auth(&self.api_key);
3915-
request = request.body(serde_json::to_string(&request_body)?);
3915+
request = request.multipart(request_body.into_multipart_form());
39163916
let result = request.send().await?;
39173917
let status_code = result.status().as_u16();
39183918
let _content_type = result.headers()[reqwest::header::CONTENT_TYPE].to_str()?.to_string();

openai_client/src/lib.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#![allow(clippy::large_enum_variant)]
33

44
pub mod client;
5+
mod multipart;
56
pub mod types;
67

78
use std::string::FromUtf8Error;
@@ -12,6 +13,7 @@ use reqwest::{Client, header::ToStrError};
1213
pub enum ConversaError {
1314
ClientError(String),
1415
InvalidData(String),
16+
IoError(String),
1517
UnexpectedStatusCode { code: u16, response: String },
1618
UnexpectedContentType(String),
1719
ErrorResponse(crate::types::ErrorResponse),
@@ -42,6 +44,31 @@ impl From<FromUtf8Error> for ConversaError {
4244
}
4345
}
4446

47+
impl From<std::io::Error> for ConversaError {
48+
fn from(value: std::io::Error) -> Self {
49+
ConversaError::IoError(value.to_string())
50+
}
51+
}
52+
53+
impl std::fmt::Display for ConversaError {
54+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55+
match self {
56+
ConversaError::ClientError(msg) => write!(f, "Client error: {}", msg),
57+
ConversaError::InvalidData(msg) => write!(f, "Invalid data: {}", msg),
58+
ConversaError::UnexpectedStatusCode { code, response } => {
59+
write!(f, "Unexpected status code {}: {}", code, response)
60+
}
61+
ConversaError::IoError(msg) => write!(f, "std::io error: {}", msg),
62+
ConversaError::UnexpectedContentType(content_type) => {
63+
write!(f, "Unexpected content type: {}", content_type)
64+
}
65+
ConversaError::ErrorResponse(err) => write!(f, "Error response: {:?}", err),
66+
ConversaError::Error(err) => write!(f, "Error: {:?}", err),
67+
}
68+
}
69+
}
70+
impl std::error::Error for ConversaError {}
71+
4572
pub type ConversaResult<T> = Result<T, ConversaError>;
4673

4774
pub struct OpenAIClientBuilder {

0 commit comments

Comments
 (0)