@@ -14,19 +14,21 @@ use http::Request;
1414use tracing:: { Instrument , Level , enabled, info_span} ;
1515
1616use crate :: client:: {
17- self , BearerAuth , Capabilities , Capable , DebugExt , Nothing , Provider , ProviderBuilder ,
18- ProviderClient ,
17+ self , BearerAuth , Capabilities , Capable , DebugExt , ModelLister , Nothing , Provider ,
18+ ProviderBuilder , ProviderClient ,
1919} ;
2020use crate :: completion:: GetTokenUsage ;
2121use crate :: http_client:: { self , HttpClientExt } ;
2222use crate :: message:: { Document , DocumentSourceKind } ;
23+ use crate :: model:: { Model , ModelList , ModelListingError } ;
2324use crate :: providers:: internal:: openai_chat_completions_compatible:: {
2425 self , CompatibleChoiceData , CompatibleChunk , CompatibleFinishReason , CompatibleStreamProfile ,
2526} ;
2627use crate :: {
2728 OneOrMany ,
2829 completion:: { self , CompletionError , CompletionRequest } ,
2930 json_utils, message,
31+ wasm_compat:: { WasmCompatSend , WasmCompatSync } ,
3032} ;
3133use serde:: { Deserialize , Serialize } ;
3234
@@ -53,7 +55,7 @@ impl<H> Capabilities<H> for DeepSeekExt {
5355 type Completion = Capable < CompletionModel < H > > ;
5456 type Embeddings = Nothing ;
5557 type Transcription = Nothing ;
56- type ModelListing = Nothing ;
58+ type ModelListing = Capable < DeepSeekModelLister < H > > ;
5759 #[ cfg( feature = "image" ) ]
5860 type ImageGeneration = Nothing ;
5961 #[ cfg( feature = "audio" ) ]
@@ -791,6 +793,82 @@ where
791793 . await
792794}
793795
796+ #[ derive( Debug , Deserialize ) ]
797+ struct ListModelsResponse {
798+ data : Vec < ListModelEntry > ,
799+ }
800+
801+ #[ derive( Debug , Deserialize ) ]
802+ struct ListModelEntry {
803+ id : String ,
804+ owned_by : String ,
805+ }
806+
807+ impl From < ListModelEntry > for Model {
808+ fn from ( value : ListModelEntry ) -> Self {
809+ let mut model = Model :: from_id ( value. id ) ;
810+ model. owned_by = Some ( value. owned_by ) ;
811+ model
812+ }
813+ }
814+
815+ /// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
816+ #[ derive( Clone ) ]
817+ pub struct DeepSeekModelLister < H = reqwest:: Client > {
818+ client : Client < H > ,
819+ }
820+
821+ impl < H > ModelLister < H > for DeepSeekModelLister < H >
822+ where
823+ H : HttpClientExt + WasmCompatSend + WasmCompatSync + ' static ,
824+ {
825+ type Client = Client < H > ;
826+
827+ fn new ( client : Self :: Client ) -> Self {
828+ Self { client }
829+ }
830+
831+ async fn list_all ( & self ) -> Result < ModelList , ModelListingError > {
832+ let path = "/models" ;
833+ let req = self . client . get ( path) ?. body ( http_client:: NoBody ) ?;
834+ let response = self
835+ . client
836+ . send :: < _ , Vec < u8 > > ( req)
837+ . await
838+ . map_err ( |error| match error {
839+ http_client:: Error :: InvalidStatusCodeWithMessage ( status, message) => {
840+ ModelListingError :: api_error_with_context (
841+ "DeepSeek" ,
842+ path,
843+ status. as_u16 ( ) ,
844+ message. as_bytes ( ) ,
845+ )
846+ }
847+ other => ModelListingError :: from ( other) ,
848+ } ) ?;
849+
850+ if !response. status ( ) . is_success ( ) {
851+ let status_code = response. status ( ) . as_u16 ( ) ;
852+ let body = response. into_body ( ) . await ?;
853+ return Err ( ModelListingError :: api_error_with_context (
854+ "DeepSeek" ,
855+ path,
856+ status_code,
857+ & body,
858+ ) ) ;
859+ }
860+
861+ let body = response. into_body ( ) . await ?;
862+ let api_resp: ListModelsResponse = serde_json:: from_slice ( & body) . map_err ( |error| {
863+ ModelListingError :: parse_error_with_context ( "DeepSeek" , path, & error, & body)
864+ } ) ?;
865+
866+ let models = api_resp. data . into_iter ( ) . map ( Model :: from) . collect ( ) ;
867+
868+ Ok ( ModelList :: new ( models) )
869+ }
870+ }
871+
794872// ================================================================
795873// DeepSeek Completion API
796874// ================================================================
@@ -813,6 +891,11 @@ pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
813891#[ cfg( test) ]
814892mod tests {
815893 use super :: * ;
894+ use crate :: client:: ModelListingClient ;
895+ use crate :: http_client:: { LazyBody , MultipartForm , Request as HttpRequest , Response } ;
896+ use bytes:: Bytes ;
897+ use std:: future:: { self , Future } ;
898+ use std:: sync:: { Arc , Mutex } ;
816899
817900 #[ test]
818901 fn test_deserialize_vec_choice ( ) {
@@ -1070,4 +1153,207 @@ mod tests {
10701153 . build ( )
10711154 . expect ( "Client::builder() failed" ) ;
10721155 }
1156+
1157+ #[ test]
1158+ fn test_deserialize_list_models_response ( ) {
1159+ let data = r#"{
1160+ "object": "list",
1161+ "data": [
1162+ {
1163+ "id": "deepseek-v4-flash",
1164+ "object": "model",
1165+ "owned_by": "deepseek"
1166+ },
1167+ {
1168+ "id": "deepseek-v4-pro",
1169+ "object": "model",
1170+ "owned_by": "deepseek"
1171+ }
1172+ ]
1173+ }"# ;
1174+
1175+ let response: ListModelsResponse = serde_json:: from_str ( data) . unwrap ( ) ;
1176+
1177+ assert_eq ! ( response. data. len( ) , 2 ) ;
1178+ assert_eq ! ( response. data[ 0 ] . id, "deepseek-v4-flash" ) ;
1179+ assert_eq ! ( response. data[ 0 ] . owned_by, "deepseek" ) ;
1180+ }
1181+
1182+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
1183+ struct CapturedRequest {
1184+ uri : String ,
1185+ }
1186+
1187+ #[ derive( Clone ) ]
1188+ enum MockResponse {
1189+ Success ( Bytes ) ,
1190+ Error ( http:: StatusCode , String ) ,
1191+ }
1192+
1193+ impl Default for MockResponse {
1194+ fn default ( ) -> Self {
1195+ Self :: Success ( Bytes :: new ( ) )
1196+ }
1197+ }
1198+
1199+ #[ derive( Clone , Default ) ]
1200+ struct RecordingHttpClient {
1201+ requests : Arc < Mutex < Vec < CapturedRequest > > > ,
1202+ response : Arc < Mutex < MockResponse > > ,
1203+ }
1204+
1205+ impl RecordingHttpClient {
1206+ fn new ( response_body : impl Into < Bytes > ) -> Self {
1207+ Self {
1208+ requests : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
1209+ response : Arc :: new ( Mutex :: new ( MockResponse :: Success ( response_body. into ( ) ) ) ) ,
1210+ }
1211+ }
1212+
1213+ fn with_error ( status : http:: StatusCode , message : impl Into < String > ) -> Self {
1214+ Self {
1215+ requests : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
1216+ response : Arc :: new ( Mutex :: new ( MockResponse :: Error ( status, message. into ( ) ) ) ) ,
1217+ }
1218+ }
1219+
1220+ fn requests ( & self ) -> Vec < CapturedRequest > {
1221+ self . requests . lock ( ) . expect ( "requests lock" ) . clone ( )
1222+ }
1223+ }
1224+
1225+ impl HttpClientExt for RecordingHttpClient {
1226+ fn send < T , U > (
1227+ & self ,
1228+ req : HttpRequest < T > ,
1229+ ) -> impl Future < Output = http_client:: Result < Response < LazyBody < U > > > > + WasmCompatSend + ' static
1230+ where
1231+ T : Into < Bytes > + WasmCompatSend ,
1232+ U : From < Bytes > + WasmCompatSend + ' static ,
1233+ {
1234+ let requests = Arc :: clone ( & self . requests ) ;
1235+ let response = self . response . lock ( ) . expect ( "response lock" ) . clone ( ) ;
1236+ let ( parts, _body) = req. into_parts ( ) ;
1237+
1238+ requests
1239+ . lock ( )
1240+ . expect ( "requests lock" )
1241+ . push ( CapturedRequest {
1242+ uri : parts. uri . to_string ( ) ,
1243+ } ) ;
1244+
1245+ async move {
1246+ let response_body = match response {
1247+ MockResponse :: Success ( response_body) => response_body,
1248+ MockResponse :: Error ( status, message) => {
1249+ return Err ( http_client:: Error :: InvalidStatusCodeWithMessage (
1250+ status, message,
1251+ ) ) ;
1252+ }
1253+ } ;
1254+ let body: LazyBody < U > = Box :: pin ( async move { Ok ( U :: from ( response_body) ) } ) ;
1255+ Response :: builder ( )
1256+ . status ( http:: StatusCode :: OK )
1257+ . body ( body)
1258+ . map_err ( http_client:: Error :: Protocol )
1259+ }
1260+ }
1261+
1262+ fn send_multipart < U > (
1263+ & self ,
1264+ _req : HttpRequest < MultipartForm > ,
1265+ ) -> impl Future < Output = http_client:: Result < Response < LazyBody < U > > > > + WasmCompatSend + ' static
1266+ where
1267+ U : From < Bytes > + WasmCompatSend + ' static ,
1268+ {
1269+ future:: ready ( Err ( http_client:: Error :: InvalidStatusCode (
1270+ http:: StatusCode :: NOT_IMPLEMENTED ,
1271+ ) ) )
1272+ }
1273+
1274+ fn send_streaming < T > (
1275+ & self ,
1276+ _req : HttpRequest < T > ,
1277+ ) -> impl Future < Output = http_client:: Result < http_client:: StreamingResponse > > + WasmCompatSend
1278+ where
1279+ T : Into < Bytes > + WasmCompatSend ,
1280+ {
1281+ future:: ready ( Err ( http_client:: Error :: InvalidStatusCode (
1282+ http:: StatusCode :: NOT_IMPLEMENTED ,
1283+ ) ) )
1284+ }
1285+ }
1286+
1287+ #[ tokio:: test]
1288+ async fn test_list_models_uses_models_endpoint ( ) {
1289+ let response_body = r#"{
1290+ "object": "list",
1291+ "data": [
1292+ {
1293+ "id": "deepseek-v4-flash",
1294+ "object": "model",
1295+ "owned_by": "deepseek"
1296+ },
1297+ {
1298+ "id": "deepseek-v4-pro",
1299+ "object": "model",
1300+ "owned_by": "deepseek"
1301+ }
1302+ ]
1303+ }"# ;
1304+
1305+ let http_client = RecordingHttpClient :: new ( response_body) ;
1306+ let client = Client :: builder ( )
1307+ . api_key ( "dummy-key" )
1308+ . http_client ( http_client. clone ( ) )
1309+ . build ( )
1310+ . expect ( "client should build" ) ;
1311+
1312+ let models = client
1313+ . list_models ( )
1314+ . await
1315+ . expect ( "list_models should succeed" ) ;
1316+
1317+ assert_eq ! ( models. len( ) , 2 ) ;
1318+ assert_eq ! ( models. data[ 0 ] . id, "deepseek-v4-flash" ) ;
1319+ assert_eq ! ( models. data[ 0 ] . r#type, None ) ;
1320+ assert_eq ! ( models. data[ 0 ] . owned_by. as_deref( ) , Some ( "deepseek" ) ) ;
1321+ assert_eq ! (
1322+ http_client. requests( ) ,
1323+ vec![ CapturedRequest {
1324+ uri: "https://api.deepseek.com/models" . to_string( )
1325+ } ]
1326+ ) ;
1327+ }
1328+
1329+ #[ tokio:: test]
1330+ async fn test_list_models_preserves_api_error_context ( ) {
1331+ let http_client = RecordingHttpClient :: with_error (
1332+ http:: StatusCode :: UNAUTHORIZED ,
1333+ r#"{"error":{"message":"invalid api key"}}"# ,
1334+ ) ;
1335+ let client = Client :: builder ( )
1336+ . api_key ( "dummy-key" )
1337+ . http_client ( http_client)
1338+ . build ( )
1339+ . expect ( "client should build" ) ;
1340+
1341+ let error = client
1342+ . list_models ( )
1343+ . await
1344+ . expect_err ( "list_models should fail" ) ;
1345+
1346+ match error {
1347+ ModelListingError :: ApiError {
1348+ status_code,
1349+ message,
1350+ } => {
1351+ assert_eq ! ( status_code, 401 ) ;
1352+ assert ! ( message. contains( "provider=DeepSeek" ) ) ;
1353+ assert ! ( message. contains( "path=/models" ) ) ;
1354+ assert ! ( message. contains( "invalid api key" ) ) ;
1355+ }
1356+ other => panic ! ( "expected api error, got {other:?}" ) ,
1357+ }
1358+ }
10731359}
0 commit comments