@@ -57,20 +57,31 @@ struct ErrorResponse {
     error: String,
 }
 
-/// Represents a model available through a provider
-#[derive(Debug, Serialize, Deserialize)]
+/// Common model information
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct Model {
-    /// Unique identifier of the model
-    pub name: String,
+    /// The model identifier
+    pub id: String,
+    /// The object type, usually "model"
+    pub object: Option<String>,
+    /// The Unix timestamp (in seconds) of when the model was created
+    pub created: Option<i64>,
+    /// The organization that owns the model
+    pub owned_by: Option<String>,
+    /// The provider that serves the model
+    pub served_by: Option<String>,
 }
 
-/// Collection of models available from a specific provider
+/// Response structure for listing models
 #[derive(Debug, Serialize, Deserialize)]
-pub struct ProviderModels {
-    /// The LLM provider
-    pub provider: Provider,
+pub struct ListModelsResponse {
+    /// The provider identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider: Option<Provider>,
+    /// Response object type, usually "list"
+    pub object: String,
     /// List of available models
-    pub models: Vec<Model>,
+    pub data: Vec<Model>,
 }
 
 /// Supported LLM providers
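
For orientation, a minimal sketch of how the new response shape deserializes. It is not part of this diff: serde_json is assumed to be available, and the JSON mirrors the fixtures used in the updated tests further down.

// Sketch only: missing optional fields (including `provider`) fall back to None,
// which is what the updated tests rely on.
fn demo_deserialize() -> Result<(), serde_json::Error> {
    let body = r#"{
        "object": "list",
        "data": [
            {"id": "llama2", "object": "model", "created": 1630000001, "owned_by": "ollama", "served_by": "ollama"}
        ]
    }"#;
    let response: ListModelsResponse = serde_json::from_str(body)?;
    assert_eq!(response.object, "list");
    assert_eq!(response.data[0].id, "llama2");
    assert!(response.provider.is_none());
    Ok(())
}
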
@@ -364,8 +375,7 @@ pub trait InferenceGatewayAPI {
     ///
     /// # Returns
     /// A list of models available from all providers
-    fn list_models(&self)
-        -> impl Future<Output = Result<Vec<ProviderModels>, GatewayError>> + Send;
+    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
 
     /// Lists available models by a specific provider
     ///
@@ -383,7 +393,7 @@ pub trait InferenceGatewayAPI {
     fn list_models_by_provider(
         &self,
         provider: Provider,
-    ) -> impl Future<Output = Result<ProviderModels, GatewayError>> + Send;
+    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
 
     /// Generates content using a specified model
     ///
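
A hedged usage sketch of the updated trait surface. The function below is illustrative, not part of the crate, but it only uses items visible in this diff (the two listing methods, Provider::Ollama, and GatewayError).

// Illustrative caller, generic over the trait, consuming ListModelsResponse
// from both listing methods.
async fn print_available_models<A: InferenceGatewayAPI>(api: &A) -> Result<(), GatewayError> {
    // All providers in one aggregated response.
    let all = api.list_models().await?;
    for model in &all.data {
        println!("{} (owned by {:?})", model.id, model.owned_by);
    }

    // A single provider; here `provider` is expected to come back as Some(...).
    let by_provider = api.list_models_by_provider(Provider::Ollama).await?;
    println!("ollama serves {} models", by_provider.data.len());
    Ok(())
}
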
@@ -501,7 +511,7 @@ impl InferenceGatewayClient {
 }
 
 impl InferenceGatewayAPI for InferenceGatewayClient {
-    async fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
+    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
         let url = format!("{}/models", self.base_url);
         let mut request = self.client.get(&url);
         if let Some(token) = &self.token {
@@ -510,7 +520,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
 
         let response = request.send().await?;
         match response.status() {
-            StatusCode::OK => Ok(response.json().await?),
+            StatusCode::OK => {
+                let json_response: ListModelsResponse = response.json().await?;
+                Ok(json_response)
+            }
             StatusCode::UNAUTHORIZED => {
                 let error: ErrorResponse = response.json().await?;
                 Err(GatewayError::Unauthorized(error.error))
@@ -533,7 +546,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
     async fn list_models_by_provider(
         &self,
         provider: Provider,
-    ) -> Result<ProviderModels, GatewayError> {
+    ) -> Result<ListModelsResponse, GatewayError> {
         let url = format!("{}/list/models?provider={}", self.base_url, provider);
         let mut request = self.client.get(&url);
         if let Some(token) = &self.token {
@@ -542,7 +555,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
 
         let response = request.send().await?;
         match response.status() {
-            StatusCode::OK => Ok(response.json().await?),
+            StatusCode::OK => {
+                let json_response: ListModelsResponse = response.json().await?;
+                Ok(json_response)
+            }
             StatusCode::UNAUTHORIZED => {
                 let error: ErrorResponse = response.json().await?;
                 Err(GatewayError::Unauthorized(error.error))
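
On the caller side, the Unauthorized arm above surfaces as GatewayError::Unauthorized. A sketch of handling it; only the variants visible in this diff are matched by name, everything else goes through the catch-all.

// Sketch only. GatewayError can be formatted with {:?} (the tests return
// Result<(), GatewayError>, which requires Debug).
async fn list_or_report(client: &InferenceGatewayClient) {
    match client.list_models().await {
        Ok(response) => println!("{} models available", response.data.len()),
        Err(GatewayError::Unauthorized(msg)) => eprintln!("check the bearer token: {msg}"),
        Err(other) => eprintln!("request failed: {other:?}"),
    }
}
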
@@ -848,12 +864,17 @@ mod tests {
     async fn test_authentication_header() -> Result<(), GatewayError> {
         let mut server = Server::new_async().await;
 
+        let mock_response = r#"{
+            "object": "list",
+            "data": []
+        }"#;
+
         let mock_with_auth = server
             .mock("GET", "/v1/models")
             .match_header("authorization", "Bearer test-token")
             .with_status(200)
             .with_header("content-type", "application/json")
-            .with_body("[]")
+            .with_body(mock_response)
             .expect(1)
             .create();
 
@@ -867,7 +888,7 @@ mod tests {
             .match_header("authorization", Matcher::Missing)
             .with_status(200)
             .with_header("content-type", "application/json")
-            .with_body("[]")
+            .with_body(mock_response)
             .expect(1)
             .create();
 
@@ -911,14 +932,18 @@ mod tests {
     async fn test_list_models() -> Result<(), GatewayError> {
         let mut server = Server::new_async().await;
 
-        let raw_response_json = r#"[
-            {
-                "provider": "ollama",
-                "models": [
-                    {"name": "llama2"}
-                ]
-            }
-        ]"#;
+        let raw_response_json = r#"{
+            "object": "list",
+            "data": [
+                {
+                    "id": "llama2",
+                    "object": "model",
+                    "created": 1630000001,
+                    "owned_by": "ollama",
+                    "served_by": "ollama"
+                }
+            ]
+        }"#;
 
         let mock = server
             .mock("GET", "/v1/models")
@@ -929,10 +954,12 @@ mod tests {
 
         let base_url = format!("{}/v1", server.url());
         let client = InferenceGatewayClient::new(&base_url);
-        let models = client.list_models().await?;
+        let response = client.list_models().await?;
 
-        assert_eq!(models.len(), 1);
-        assert_eq!(models[0].models[0].name, "llama2");
+        assert!(response.provider.is_none());
+        assert_eq!(response.object, "list");
+        assert_eq!(response.data.len(), 1);
+        assert_eq!(response.data[0].id, "llama2");
         mock.assert();
 
         Ok(())
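
Because provider on ListModelsResponse carries #[serde(skip_serializing_if = "Option::is_none")], the aggregated response serializes without a provider key at all when it is None. A test one could add to pin that down; it is not in this diff and assumes serde_json as a dev-dependency.

#[test]
fn provider_key_is_omitted_when_none() {
    let response = ListModelsResponse {
        provider: None,
        object: "list".to_string(),
        data: vec![],
    };
    // serde_json keeps struct declaration order and skips the None provider.
    let body = serde_json::to_string(&response).expect("ListModelsResponse serializes");
    assert_eq!(body, r#"{"object":"list","data":[]}"#);
}
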
@@ -944,9 +971,16 @@
 
         let raw_json_response = r#"{
             "provider":"ollama",
-            "models": [{
-                "name": "llama2"
-            }]
+            "object":"list",
+            "data": [
+                {
+                    "id": "llama2",
+                    "object": "model",
+                    "created": 1630000001,
+                    "owned_by": "ollama",
+                    "served_by": "ollama"
+                }
+            ]
         }"#;
 
         let mock = server
@@ -958,10 +992,11 @@
 
         let base_url = format!("{}/v1", server.url());
         let client = InferenceGatewayClient::new(&base_url);
-        let models = client.list_models_by_provider(Provider::Ollama).await?;
+        let response = client.list_models_by_provider(Provider::Ollama).await?;
 
-        assert_eq!(models.provider, Provider::Ollama);
-        assert_eq!(models.models[0].name, "llama2");
+        assert!(response.provider.is_some());
+        assert_eq!(response.provider, Some(Provider::Ollama));
+        assert_eq!(response.data[0].id, "llama2");
         mock.assert();
 
         Ok(())