Skip to content

Commit 4d3dc52

Browse files
committed
add chatkit api
1 parent 459a862 commit 4d3dc52

8 files changed

Lines changed: 821 additions & 3 deletions

File tree

async-openai/src/chatkit.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
use serde::Serialize;
2+
3+
use crate::{
4+
config::Config,
5+
error::OpenAIError,
6+
types::chatkit::{
7+
ChatSessionResource, CreateChatSessionBody, DeletedThreadResource, ThreadItemListResource,
8+
ThreadListResource, ThreadResource,
9+
},
10+
Client,
11+
};
12+
13+
/// ChatKit API for managing sessions and threads.
14+
///
15+
/// Related guide: [ChatKit](https://platform.openai.com/docs/api-reference/chatkit)
16+
pub struct Chatkit<'c, C: Config> {
17+
client: &'c Client<C>,
18+
}
19+
20+
impl<'c, C: Config> Chatkit<'c, C> {
21+
pub fn new(client: &'c Client<C>) -> Self {
22+
Self { client }
23+
}
24+
25+
/// Access sessions API.
26+
pub fn sessions(&self) -> ChatkitSessions<'_, C> {
27+
ChatkitSessions::new(self.client)
28+
}
29+
30+
/// Access threads API.
31+
pub fn threads(&self) -> ChatkitThreads<'_, C> {
32+
ChatkitThreads::new(self.client)
33+
}
34+
}
35+
36+
/// ChatKit sessions API.
37+
pub struct ChatkitSessions<'c, C: Config> {
38+
client: &'c Client<C>,
39+
}
40+
41+
impl<'c, C: Config> ChatkitSessions<'c, C> {
42+
pub fn new(client: &'c Client<C>) -> Self {
43+
Self { client }
44+
}
45+
46+
/// Create a ChatKit session.
47+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
48+
pub async fn create(
49+
&self,
50+
request: CreateChatSessionBody,
51+
) -> Result<ChatSessionResource, OpenAIError> {
52+
self.client.post("/chatkit/sessions", request).await
53+
}
54+
55+
/// Cancel a ChatKit session.
56+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
57+
pub async fn cancel(&self, session_id: &str) -> Result<ChatSessionResource, OpenAIError> {
58+
self.client
59+
.post(
60+
&format!("/chatkit/sessions/{session_id}/cancel"),
61+
serde_json::json!({}),
62+
)
63+
.await
64+
}
65+
}
66+
67+
/// ChatKit threads API.
68+
pub struct ChatkitThreads<'c, C: Config> {
69+
client: &'c Client<C>,
70+
}
71+
72+
impl<'c, C: Config> ChatkitThreads<'c, C> {
73+
pub fn new(client: &'c Client<C>) -> Self {
74+
Self { client }
75+
}
76+
77+
/// List ChatKit threads.
78+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
79+
pub async fn list<Q>(&self, query: &Q) -> Result<ThreadListResource, OpenAIError>
80+
where
81+
Q: Serialize + ?Sized,
82+
{
83+
self.client.get_with_query("/chatkit/threads", &query).await
84+
}
85+
86+
/// Retrieve a ChatKit thread.
87+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
88+
pub async fn retrieve(&self, thread_id: &str) -> Result<ThreadResource, OpenAIError> {
89+
self.client
90+
.get(&format!("/chatkit/threads/{thread_id}"))
91+
.await
92+
}
93+
94+
/// Delete a ChatKit thread.
95+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
96+
pub async fn delete(&self, thread_id: &str) -> Result<DeletedThreadResource, OpenAIError> {
97+
self.client
98+
.delete(&format!("/chatkit/threads/{thread_id}"))
99+
.await
100+
}
101+
102+
/// List ChatKit thread items.
103+
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
104+
pub async fn list_items<Q>(
105+
&self,
106+
thread_id: &str,
107+
query: &Q,
108+
) -> Result<ThreadItemListResource, OpenAIError>
109+
where
110+
Q: Serialize + ?Sized,
111+
{
112+
self.client
113+
.get_with_query(&format!("/chatkit/threads/{thread_id}/items"), &query)
114+
.await
115+
}
116+
}

async-openai/src/client.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, Request
77
use serde::{de::DeserializeOwned, Serialize};
88

99
use crate::{
10+
chatkit::Chatkit,
1011
config::{Config, OpenAIConfig},
1112
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
1213
file::Files,
@@ -188,6 +189,11 @@ impl<C: Config> Client<C> {
188189
Evals::new(self)
189190
}
190191

192+
/// To call [Chatkit] group related APIs using this client.
193+
pub fn chatkit(&self) -> Chatkit<'_, C> {
194+
Chatkit::new(self)
195+
}
196+
191197
pub fn config(&self) -> &C {
192198
&self.config
193199
}

async-openai/src/config.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use reqwest::header::{HeaderMap, AUTHORIZATION};
33
use secrecy::{ExposeSecret, SecretString};
44
use serde::Deserialize;
55

6+
use crate::error::OpenAIError;
7+
68
/// Default v1 API base url
79
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
810
/// Organization header
@@ -59,6 +61,8 @@ pub struct OpenAIConfig {
5961
api_key: SecretString,
6062
org_id: String,
6163
project_id: String,
64+
#[serde(skip)]
65+
custom_headers: HeaderMap,
6266
}
6367

6468
impl Default for OpenAIConfig {
@@ -70,6 +74,7 @@ impl Default for OpenAIConfig {
7074
.into(),
7175
org_id: Default::default(),
7276
project_id: Default::default(),
77+
custom_headers: HeaderMap::new(),
7378
}
7479
}
7580
}
@@ -104,6 +109,21 @@ impl OpenAIConfig {
104109
self
105110
}
106111

112+
/// Add a custom header that will be included in all requests.
113+
/// Headers are merged with existing headers, with custom headers taking precedence.
114+
pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
115+
where
116+
K: reqwest::header::IntoHeaderName,
117+
V: TryInto<reqwest::header::HeaderValue>,
118+
V::Error: Into<reqwest::header::InvalidHeaderValue>,
119+
{
120+
let header_value = value.try_into().map_err(|e| {
121+
OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
122+
})?;
123+
self.custom_headers.insert(key, header_value);
124+
Ok(self)
125+
}
126+
107127
pub fn org_id(&self) -> &str {
108128
&self.org_id
109129
}
@@ -134,9 +154,10 @@ impl Config for OpenAIConfig {
134154
.unwrap(),
135155
);
136156

137-
// hack for Assistants APIs
138-
// Calls to the Assistants API require that you pass a Beta header
139-
// headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
157+
// Merge custom headers, with custom headers taking precedence
158+
for (key, value) in self.custom_headers.iter() {
159+
headers.insert(key, value.clone());
160+
}
140161

141162
headers
142163
}

async-openai/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ mod audio;
145145
mod audit_logs;
146146
mod batches;
147147
mod chat;
148+
mod chatkit;
148149
mod client;
149150
mod completion;
150151
pub mod config;
@@ -193,6 +194,7 @@ pub use audio::Audio;
193194
pub use audit_logs::AuditLogs;
194195
pub use batches::Batches;
195196
pub use chat::Chat;
197+
pub use chatkit::Chatkit;
196198
pub use client::Client;
197199
pub use completion::Completions;
198200
pub use container_files::ContainerFiles;
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod session;
2+
mod thread;
3+
4+
pub use session::*;
5+
pub use thread::*;

0 commit comments

Comments
 (0)