Skip to content

Commit 5ca2eb5

Browse files
committed
refactor: Update model response structures for clarity and consistency
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent 6079995 commit 5ca2eb5

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

src/lib.rs

Lines changed: 70 additions & 35 deletions
Original file line number · Diff line number · Diff line change
@@ -57,20 +57,31 @@ struct ErrorResponse {
5757
error: String,
5858
}
5959

60-
/// Represents a model available through a provider
61-
#[derive(Debug, Serialize, Deserialize)]
60+
/// Common model information
61+
#[derive(Debug, Serialize, Deserialize, Clone)]
6262
pub struct Model {
63-
/// Unique identifier of the model
64-
pub name: String,
63+
/// The model identifier
64+
pub id: String,
65+
/// The object type, usually "model"
66+
pub object: Option<String>,
67+
/// The Unix timestamp (in seconds) of when the model was created
68+
pub created: Option<i64>,
69+
/// The organization that owns the model
70+
pub owned_by: Option<String>,
71+
/// The provider that serves the model
72+
pub served_by: Option<String>,
6573
}
6674

67-
/// Collection of models available from a specific provider
75+
/// Response structure for listing models
6876
#[derive(Debug, Serialize, Deserialize)]
69-
pub struct ProviderModels {
70-
/// The LLM provider
71-
pub provider: Provider,
77+
pub struct ListModelsResponse {
78+
/// The provider identifier
79+
#[serde(skip_serializing_if = "Option::is_none")]
80+
pub provider: Option<Provider>,
81+
/// Response object type, usually "list"
82+
pub object: String,
7283
/// List of available models
73-
pub models: Vec<Model>,
84+
pub data: Vec<Model>,
7485
}
7586

7687
/// Supported LLM providers
@@ -364,8 +375,7 @@ pub trait InferenceGatewayAPI {
364375
///
365376
/// # Returns
366377
/// A list of models available from all providers
367-
fn list_models(&self)
368-
-> impl Future<Output = Result<Vec<ProviderModels>, GatewayError>> + Send;
378+
fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
369379

370380
/// Lists available models by a specific provider
371381
///
@@ -383,7 +393,7 @@ pub trait InferenceGatewayAPI {
383393
fn list_models_by_provider(
384394
&self,
385395
provider: Provider,
386-
) -> impl Future<Output = Result<ProviderModels, GatewayError>> + Send;
396+
) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
387397

388398
/// Generates content using a specified model
389399
///
@@ -501,7 +511,7 @@ impl InferenceGatewayClient {
501511
}
502512

503513
impl InferenceGatewayAPI for InferenceGatewayClient {
504-
async fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
514+
async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
505515
let url = format!("{}/models", self.base_url);
506516
let mut request = self.client.get(&url);
507517
if let Some(token) = &self.token {
@@ -510,7 +520,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
510520

511521
let response = request.send().await?;
512522
match response.status() {
513-
StatusCode::OK => Ok(response.json().await?),
523+
StatusCode::OK => {
524+
let json_response: ListModelsResponse = response.json().await?;
525+
Ok(json_response)
526+
}
514527
StatusCode::UNAUTHORIZED => {
515528
let error: ErrorResponse = response.json().await?;
516529
Err(GatewayError::Unauthorized(error.error))
@@ -533,7 +546,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
533546
async fn list_models_by_provider(
534547
&self,
535548
provider: Provider,
536-
) -> Result<ProviderModels, GatewayError> {
549+
) -> Result<ListModelsResponse, GatewayError> {
537550
let url = format!("{}/list/models?provider={}", self.base_url, provider);
538551
let mut request = self.client.get(&url);
539552
if let Some(token) = &self.token {
@@ -542,7 +555,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
542555

543556
let response = request.send().await?;
544557
match response.status() {
545-
StatusCode::OK => Ok(response.json().await?),
558+
StatusCode::OK => {
559+
let json_response: ListModelsResponse = response.json().await?;
560+
Ok(json_response)
561+
}
546562
StatusCode::UNAUTHORIZED => {
547563
let error: ErrorResponse = response.json().await?;
548564
Err(GatewayError::Unauthorized(error.error))
@@ -848,12 +864,17 @@ mod tests {
848864
async fn test_authentication_header() -> Result<(), GatewayError> {
849865
let mut server = Server::new_async().await;
850866

867+
let mock_response = r#"{
868+
"object": "list",
869+
"data": []
870+
}"#;
871+
851872
let mock_with_auth = server
852873
.mock("GET", "/v1/models")
853874
.match_header("authorization", "Bearer test-token")
854875
.with_status(200)
855876
.with_header("content-type", "application/json")
856-
.with_body("[]")
877+
.with_body(mock_response)
857878
.expect(1)
858879
.create();
859880

@@ -867,7 +888,7 @@ mod tests {
867888
.match_header("authorization", Matcher::Missing)
868889
.with_status(200)
869890
.with_header("content-type", "application/json")
870-
.with_body("[]")
891+
.with_body(mock_response)
871892
.expect(1)
872893
.create();
873894

@@ -911,14 +932,18 @@ mod tests {
911932
async fn test_list_models() -> Result<(), GatewayError> {
912933
let mut server = Server::new_async().await;
913934

914-
let raw_response_json = r#"[
915-
{
916-
"provider": "ollama",
917-
"models": [
918-
{"name": "llama2"}
919-
]
920-
}
921-
]"#;
935+
let raw_response_json = r#"{
936+
"object": "list",
937+
"data": [
938+
{
939+
"id": "llama2",
940+
"object": "model",
941+
"created": 1630000001,
942+
"owned_by": "ollama",
943+
"served_by": "ollama"
944+
}
945+
]
946+
}"#;
922947

923948
let mock = server
924949
.mock("GET", "/v1/models")
@@ -929,10 +954,12 @@ mod tests {
929954

930955
let base_url = format!("{}/v1", server.url());
931956
let client = InferenceGatewayClient::new(&base_url);
932-
let models = client.list_models().await?;
957+
let response = client.list_models().await?;
933958

934-
assert_eq!(models.len(), 1);
935-
assert_eq!(models[0].models[0].name, "llama2");
959+
assert!(response.provider.is_none());
960+
assert_eq!(response.object, "list");
961+
assert_eq!(response.data.len(), 1);
962+
assert_eq!(response.data[0].id, "llama2");
936963
mock.assert();
937964

938965
Ok(())
@@ -944,9 +971,16 @@ mod tests {
944971

945972
let raw_json_response = r#"{
946973
"provider":"ollama",
947-
"models": [{
948-
"name": "llama2"
949-
}]
974+
"object":"list",
975+
"data": [
976+
{
977+
"id": "llama2",
978+
"object": "model",
979+
"created": 1630000001,
980+
"owned_by": "ollama",
981+
"served_by": "ollama"
982+
}
983+
]
950984
}"#;
951985

952986
let mock = server
@@ -958,10 +992,11 @@ mod tests {
958992

959993
let base_url = format!("{}/v1", server.url());
960994
let client = InferenceGatewayClient::new(&base_url);
961-
let models = client.list_models_by_provider(Provider::Ollama).await?;
995+
let response = client.list_models_by_provider(Provider::Ollama).await?;
962996

963-
assert_eq!(models.provider, Provider::Ollama);
964-
assert_eq!(models.models[0].name, "llama2");
997+
assert!(response.provider.is_some());
998+
assert_eq!(response.provider, Some(Provider::Ollama));
999+
assert_eq!(response.data[0].id, "llama2");
9651000
mock.assert();
9661001

9671002
Ok(())

0 commit comments

Comments (0)