Skip to content

Commit 9d62d29

Browse files
authored
feat(core): add DeepSeek model listing api (0xPlaygrounds#1672)
* feat(deepseek): implement ModelLister for DeepSeek API and add model listing functionality * feat(deepseek): add DeepSeek client initialization to model listing tests
1 parent c288fd4 commit 9d62d29

2 files changed

Lines changed: 296 additions & 4 deletions

File tree

rig/rig-core/src/client/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ mod wasm_model_listing_compile_checks {
765765
use super::{ModelListingClient, Nothing};
766766
use crate::{
767767
http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response},
768-
providers::{anthropic, mistral, ollama, openai, openrouter},
768+
providers::{anthropic, deepseek, mistral, ollama, openai, openrouter},
769769
wasm_compat::WasmCompatSend,
770770
};
771771
use bytes::Bytes;
@@ -850,6 +850,12 @@ mod wasm_model_listing_compile_checks {
850850
.http_client(WasmOnlyHttpClient::default())
851851
.build()
852852
.map(assert_model_listing_client);
853+
854+
let _ = deepseek::Client::builder()
855+
.api_key("dummy-key")
856+
.http_client(WasmOnlyHttpClient::default())
857+
.build()
858+
.map(assert_model_listing_client);
853859
}
854860

855861
#[allow(dead_code)]

rig/rig-core/src/providers/deepseek.rs

Lines changed: 289 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@ use http::Request;
1414
use tracing::{Instrument, Level, enabled, info_span};
1515

1616
use crate::client::{
17-
self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
18-
ProviderClient,
17+
self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
18+
ProviderBuilder, ProviderClient,
1919
};
2020
use crate::completion::GetTokenUsage;
2121
use crate::http_client::{self, HttpClientExt};
2222
use crate::message::{Document, DocumentSourceKind};
23+
use crate::model::{Model, ModelList, ModelListingError};
2324
use crate::providers::internal::openai_chat_completions_compatible::{
2425
self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
2526
};
2627
use crate::{
2728
OneOrMany,
2829
completion::{self, CompletionError, CompletionRequest},
2930
json_utils, message,
31+
wasm_compat::{WasmCompatSend, WasmCompatSync},
3032
};
3133
use serde::{Deserialize, Serialize};
3234

@@ -53,7 +55,7 @@ impl<H> Capabilities<H> for DeepSeekExt {
5355
type Completion = Capable<CompletionModel<H>>;
5456
type Embeddings = Nothing;
5557
type Transcription = Nothing;
56-
type ModelListing = Nothing;
58+
type ModelListing = Capable<DeepSeekModelLister<H>>;
5759
#[cfg(feature = "image")]
5860
type ImageGeneration = Nothing;
5961
#[cfg(feature = "audio")]
@@ -791,6 +793,82 @@ where
791793
.await
792794
}
793795

796+
#[derive(Debug, Deserialize)]
797+
struct ListModelsResponse {
798+
data: Vec<ListModelEntry>,
799+
}
800+
801+
#[derive(Debug, Deserialize)]
802+
struct ListModelEntry {
803+
id: String,
804+
owned_by: String,
805+
}
806+
807+
impl From<ListModelEntry> for Model {
808+
fn from(value: ListModelEntry) -> Self {
809+
let mut model = Model::from_id(value.id);
810+
model.owned_by = Some(value.owned_by);
811+
model
812+
}
813+
}
814+
815+
/// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
816+
#[derive(Clone)]
817+
pub struct DeepSeekModelLister<H = reqwest::Client> {
818+
client: Client<H>,
819+
}
820+
821+
impl<H> ModelLister<H> for DeepSeekModelLister<H>
822+
where
823+
H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
824+
{
825+
type Client = Client<H>;
826+
827+
fn new(client: Self::Client) -> Self {
828+
Self { client }
829+
}
830+
831+
async fn list_all(&self) -> Result<ModelList, ModelListingError> {
832+
let path = "/models";
833+
let req = self.client.get(path)?.body(http_client::NoBody)?;
834+
let response = self
835+
.client
836+
.send::<_, Vec<u8>>(req)
837+
.await
838+
.map_err(|error| match error {
839+
http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
840+
ModelListingError::api_error_with_context(
841+
"DeepSeek",
842+
path,
843+
status.as_u16(),
844+
message.as_bytes(),
845+
)
846+
}
847+
other => ModelListingError::from(other),
848+
})?;
849+
850+
if !response.status().is_success() {
851+
let status_code = response.status().as_u16();
852+
let body = response.into_body().await?;
853+
return Err(ModelListingError::api_error_with_context(
854+
"DeepSeek",
855+
path,
856+
status_code,
857+
&body,
858+
));
859+
}
860+
861+
let body = response.into_body().await?;
862+
let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
863+
ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
864+
})?;
865+
866+
let models = api_resp.data.into_iter().map(Model::from).collect();
867+
868+
Ok(ModelList::new(models))
869+
}
870+
}
871+
794872
// ================================================================
795873
// DeepSeek Completion API
796874
// ================================================================
@@ -813,6 +891,11 @@ pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
813891
#[cfg(test)]
814892
mod tests {
815893
use super::*;
894+
use crate::client::ModelListingClient;
895+
use crate::http_client::{LazyBody, MultipartForm, Request as HttpRequest, Response};
896+
use bytes::Bytes;
897+
use std::future::{self, Future};
898+
use std::sync::{Arc, Mutex};
816899

817900
#[test]
818901
fn test_deserialize_vec_choice() {
@@ -1070,4 +1153,207 @@ mod tests {
10701153
.build()
10711154
.expect("Client::builder() failed");
10721155
}
1156+
1157+
#[test]
1158+
fn test_deserialize_list_models_response() {
1159+
let data = r#"{
1160+
"object": "list",
1161+
"data": [
1162+
{
1163+
"id": "deepseek-v4-flash",
1164+
"object": "model",
1165+
"owned_by": "deepseek"
1166+
},
1167+
{
1168+
"id": "deepseek-v4-pro",
1169+
"object": "model",
1170+
"owned_by": "deepseek"
1171+
}
1172+
]
1173+
}"#;
1174+
1175+
let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1176+
1177+
assert_eq!(response.data.len(), 2);
1178+
assert_eq!(response.data[0].id, "deepseek-v4-flash");
1179+
assert_eq!(response.data[0].owned_by, "deepseek");
1180+
}
1181+
1182+
#[derive(Debug, Clone, PartialEq, Eq)]
1183+
struct CapturedRequest {
1184+
uri: String,
1185+
}
1186+
1187+
#[derive(Clone)]
1188+
enum MockResponse {
1189+
Success(Bytes),
1190+
Error(http::StatusCode, String),
1191+
}
1192+
1193+
impl Default for MockResponse {
1194+
fn default() -> Self {
1195+
Self::Success(Bytes::new())
1196+
}
1197+
}
1198+
1199+
#[derive(Clone, Default)]
1200+
struct RecordingHttpClient {
1201+
requests: Arc<Mutex<Vec<CapturedRequest>>>,
1202+
response: Arc<Mutex<MockResponse>>,
1203+
}
1204+
1205+
impl RecordingHttpClient {
1206+
fn new(response_body: impl Into<Bytes>) -> Self {
1207+
Self {
1208+
requests: Arc::new(Mutex::new(Vec::new())),
1209+
response: Arc::new(Mutex::new(MockResponse::Success(response_body.into()))),
1210+
}
1211+
}
1212+
1213+
fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
1214+
Self {
1215+
requests: Arc::new(Mutex::new(Vec::new())),
1216+
response: Arc::new(Mutex::new(MockResponse::Error(status, message.into()))),
1217+
}
1218+
}
1219+
1220+
fn requests(&self) -> Vec<CapturedRequest> {
1221+
self.requests.lock().expect("requests lock").clone()
1222+
}
1223+
}
1224+
1225+
impl HttpClientExt for RecordingHttpClient {
1226+
fn send<T, U>(
1227+
&self,
1228+
req: HttpRequest<T>,
1229+
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1230+
where
1231+
T: Into<Bytes> + WasmCompatSend,
1232+
U: From<Bytes> + WasmCompatSend + 'static,
1233+
{
1234+
let requests = Arc::clone(&self.requests);
1235+
let response = self.response.lock().expect("response lock").clone();
1236+
let (parts, _body) = req.into_parts();
1237+
1238+
requests
1239+
.lock()
1240+
.expect("requests lock")
1241+
.push(CapturedRequest {
1242+
uri: parts.uri.to_string(),
1243+
});
1244+
1245+
async move {
1246+
let response_body = match response {
1247+
MockResponse::Success(response_body) => response_body,
1248+
MockResponse::Error(status, message) => {
1249+
return Err(http_client::Error::InvalidStatusCodeWithMessage(
1250+
status, message,
1251+
));
1252+
}
1253+
};
1254+
let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
1255+
Response::builder()
1256+
.status(http::StatusCode::OK)
1257+
.body(body)
1258+
.map_err(http_client::Error::Protocol)
1259+
}
1260+
}
1261+
1262+
fn send_multipart<U>(
1263+
&self,
1264+
_req: HttpRequest<MultipartForm>,
1265+
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1266+
where
1267+
U: From<Bytes> + WasmCompatSend + 'static,
1268+
{
1269+
future::ready(Err(http_client::Error::InvalidStatusCode(
1270+
http::StatusCode::NOT_IMPLEMENTED,
1271+
)))
1272+
}
1273+
1274+
fn send_streaming<T>(
1275+
&self,
1276+
_req: HttpRequest<T>,
1277+
) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
1278+
where
1279+
T: Into<Bytes> + WasmCompatSend,
1280+
{
1281+
future::ready(Err(http_client::Error::InvalidStatusCode(
1282+
http::StatusCode::NOT_IMPLEMENTED,
1283+
)))
1284+
}
1285+
}
1286+
1287+
#[tokio::test]
1288+
async fn test_list_models_uses_models_endpoint() {
1289+
let response_body = r#"{
1290+
"object": "list",
1291+
"data": [
1292+
{
1293+
"id": "deepseek-v4-flash",
1294+
"object": "model",
1295+
"owned_by": "deepseek"
1296+
},
1297+
{
1298+
"id": "deepseek-v4-pro",
1299+
"object": "model",
1300+
"owned_by": "deepseek"
1301+
}
1302+
]
1303+
}"#;
1304+
1305+
let http_client = RecordingHttpClient::new(response_body);
1306+
let client = Client::builder()
1307+
.api_key("dummy-key")
1308+
.http_client(http_client.clone())
1309+
.build()
1310+
.expect("client should build");
1311+
1312+
let models = client
1313+
.list_models()
1314+
.await
1315+
.expect("list_models should succeed");
1316+
1317+
assert_eq!(models.len(), 2);
1318+
assert_eq!(models.data[0].id, "deepseek-v4-flash");
1319+
assert_eq!(models.data[0].r#type, None);
1320+
assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1321+
assert_eq!(
1322+
http_client.requests(),
1323+
vec![CapturedRequest {
1324+
uri: "https://api.deepseek.com/models".to_string()
1325+
}]
1326+
);
1327+
}
1328+
1329+
#[tokio::test]
1330+
async fn test_list_models_preserves_api_error_context() {
1331+
let http_client = RecordingHttpClient::with_error(
1332+
http::StatusCode::UNAUTHORIZED,
1333+
r#"{"error":{"message":"invalid api key"}}"#,
1334+
);
1335+
let client = Client::builder()
1336+
.api_key("dummy-key")
1337+
.http_client(http_client)
1338+
.build()
1339+
.expect("client should build");
1340+
1341+
let error = client
1342+
.list_models()
1343+
.await
1344+
.expect_err("list_models should fail");
1345+
1346+
match error {
1347+
ModelListingError::ApiError {
1348+
status_code,
1349+
message,
1350+
} => {
1351+
assert_eq!(status_code, 401);
1352+
assert!(message.contains("provider=DeepSeek"));
1353+
assert!(message.contains("path=/models"));
1354+
assert!(message.contains("invalid api key"));
1355+
}
1356+
other => panic!("expected api error, got {other:?}"),
1357+
}
1358+
}
10731359
}

0 commit comments

Comments
 (0)