From a53e5a1bbef29f70d2168e79e1d5edb27c119f5d Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sat, 4 Oct 2025 11:38:48 +0530 Subject: [PATCH 01/10] Feature:Add tests for json_schema.rs --- Cargo.lock | 17 + Cargo.toml | 1 + src/base/json_schema.rs | 978 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 996 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index a01a63d3..8e58c5bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1306,6 +1306,7 @@ dependencies = [ "const_format", "derivative", "env_logger", + "expect-test", "futures", "globset", "google-cloud-aiplatform-v1", @@ -1791,6 +1792,12 @@ dependencies = [ "syn 2.0.105", ] +[[package]] +name = "dissimilar" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8975ffdaa0ef3661bfe02dbdcc06c9f829dfafe6a3c474de366a8d5e44276921" + [[package]] name = "dlv-list" version = "0.5.2" @@ -1966,6 +1973,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "expect-test" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63af43ff4431e848fb47472a920f14fa71c24de13255a5692e93d4e90302acb0" +dependencies = [ + "dissimilar", + "once_cell", +] + [[package]] name = "fastrand" version = "1.9.0" diff --git a/Cargo.toml b/Cargo.toml index 5cb5b461..c25a3bd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -146,3 +146,4 @@ azure_storage_blobs = { version = "0.21.0", default-features = false, features = "hmac_rust", ] } serde_path_to_error = "0.1.17" +expect-test = "1.5.0" diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs index 9cad4ad5..cd6d7ee7 100644 --- a/src/base/json_schema.rs +++ b/src/base/json_schema.rs @@ -352,3 +352,981 @@ pub fn build_json_schema( }, }) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::base::schema::*; + use expect_test::expect; + use serde_json::json; + use std::sync::Arc; + + fn create_test_options() -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: true, + extract_descriptions: false, + top_level_must_be_object: false, + } + } + + fn create_test_options_with_extracted_descriptions() -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: true, + extract_descriptions: true, + top_level_must_be_object: false, + } + } + + fn create_test_options_always_required() -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: true, + supports_format: true, + extract_descriptions: false, + top_level_must_be_object: false, + } + } + + fn create_test_options_top_level_object() -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: true, + extract_descriptions: false, + top_level_must_be_object: true, + } + } + + fn schema_to_json(schema: &SchemaObject) -> serde_json::Value { + serde_json::to_value(schema).unwrap() + } + + #[test] + fn test_basic_types_str() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_bool() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Bool), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "boolean" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_int64() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "integer" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_float32() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Float32), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "number" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_float64() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Float64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "number" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_bytes() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Bytes), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_range() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Range), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A range represented by a list of two positions, start pos (inclusive), end pos (exclusive).", + "items": { + "type": "integer" + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }"#]].assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_uuid() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Uuid), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000", + "format": "uuid", + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_date() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Date), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A date in YYYY-MM-DD format, e.g. 2025-03-27", + "format": "date", + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_time() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Time), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A time in HH:MM:SS format, e.g. 13:32:12", + "format": "time", + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_local_date_time() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::LocalDateTime), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "Date time without timezone offset in YYYY-MM-DDTHH:MM:SS format, e.g. 2025-03-27T13:32:12", + "format": "date-time", + "type": "string" + }"#]].assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_offset_date_time() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::OffsetDateTime), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00", + "format": "date-time", + "type": "string" + }"#]].assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_time_delta() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::TimeDelta), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A duration, e.g. 'PT1H2M3S' (ISO 8601) or '1 day 2 hours 3 seconds'", + "format": "duration", + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_json() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Json), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect!["{}"].assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_vector() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Vector(VectorTypeSchema { + element_type: Box::new(BasicValueType::Str), + dimension: Some(3), + })), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "items": { + "type": "string" + }, + "maxItems": 3, + "minItems": 3, + "type": "array" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_basic_types_union() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Union(UnionTypeSchema { + types: vec![BasicValueType::Str, BasicValueType::Int64], + })), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_nullable_basic_type() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: true, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_struct_type_simple() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "age", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "age": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": [ + "age", + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_struct_type_with_optional_field() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "age", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: true, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "age": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_struct_type_with_description() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + )]), + description: Some("A person".into()), + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "description": "A person", + "properties": { + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_struct_type_with_extracted_descriptions() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + )]), + description: Some("A person".into()), + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options_with_extracted_descriptions(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + + // Check that description was extracted to extra instructions + assert!(result.extra_instructions.is_some()); + let instructions = result.extra_instructions.unwrap(); + assert!(instructions.contains("A person")); + } + + #[test] + fn test_struct_type_always_required() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "age", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: true, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options_always_required(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "age": { + "type": [ + "integer", + "null" + ] + }, + "name": { + "type": "string" + } + }, + "required": [ + "age", + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_table_type_utable() { + let value_type = EnrichedValueType { + typ: ValueType::Table(TableSchema { + kind: TableKind::UTable, + row: StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "id", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "items": { + "additionalProperties": false, + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": [ + "id", + "name" + ], + "type": "object" + }, + "type": "array" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_table_type_ktable() { + let value_type = EnrichedValueType { + typ: ValueType::Table(TableSchema { + kind: TableKind::KTable(KTableInfo { num_key_parts: 1 }), + row: StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "id", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "items": { + "additionalProperties": false, + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": [ + "id", + "name" + ], + "type": "object" + }, + "type": "array" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_table_type_ltable() { + let value_type = EnrichedValueType { + typ: ValueType::Table(TableSchema { + kind: TableKind::LTable, + row: StructSchema { + fields: Arc::new(vec![FieldSchema::new( + "value", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + )]), + description: None, + }, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "items": { + "additionalProperties": false, + "properties": { + "value": { + "type": "string" + } + }, + "required": [ + "value" + ], + "type": "object" + }, + "type": "array" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_top_level_must_be_object_with_basic_type() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options_top_level_object(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "value": { + "type": "string" + } + }, + "required": [ + "value" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + + // Check that value extractor has the wrapper field name + assert_eq!( + result.value_extractor.object_wrapper_field_name, + Some("value".to_string()) + ); + } + + #[test] + fn test_top_level_must_be_object_with_struct_type() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + )]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options_top_level_object(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + + // Check that value extractor has no wrapper field name since it's already a struct + assert_eq!(result.value_extractor.object_wrapper_field_name, None); + } + + #[test] + fn test_nested_struct() { + let value_type = EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![FieldSchema::new( + "person", + EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(vec![ + FieldSchema::new( + "name", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + FieldSchema::new( + "age", + EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Int64), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + ), + ]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }, + )]), + description: None, + }), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "additionalProperties": false, + "properties": { + "person": { + "additionalProperties": false, + "properties": { + "age": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": [ + "age", + "name" + ], + "type": "object" + } + }, + "required": [ + "person" + ], + "type": "object" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } + + #[test] + fn test_value_extractor_basic_type() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options(); + let result = build_json_schema(value_type, options).unwrap(); + + // Test extracting a string value + let json_value = json!("hello world"); + let extracted = result.value_extractor.extract_value(json_value).unwrap(); + assert!( + matches!(extracted, crate::base::value::Value::Basic(crate::base::value::BasicValue::Str(s)) if s.as_ref() == "hello world") + ); + } + + #[test] + fn test_value_extractor_with_wrapper() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Str), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = create_test_options_top_level_object(); + let result = build_json_schema(value_type, options).unwrap(); + + // Test extracting a wrapped value + let json_value = json!({"value": "hello world"}); + let extracted = result.value_extractor.extract_value(json_value).unwrap(); + assert!( + matches!(extracted, crate::base::value::Value::Basic(crate::base::value::BasicValue::Str(s)) if s.as_ref() == "hello world") + ); + } + + #[test] + fn test_no_format_support() { + let value_type = EnrichedValueType { + typ: ValueType::Basic(BasicValueType::Uuid), + nullable: false, + attrs: Arc::new(BTreeMap::new()), + }; + let options = ToJsonSchemaOptions { + fields_always_required: false, + supports_format: false, + extract_descriptions: false, + top_level_must_be_object: false, + }; + let result = build_json_schema(value_type, options).unwrap(); + let json_schema = schema_to_json(&result.schema); + + expect![[r#" + { + "description": "A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000", + "type": "string" + }"#]] + .assert_eq(&serde_json::to_string_pretty(&json_schema).unwrap()); + } +} From 33fdd66b62c3bc3015b34e5afa10cd9d0baba614 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sat, 4 Oct 2025 12:02:44 +0530 Subject: [PATCH 02/10] Feature:Enable programmatically pass in api_key besides reading from env --- python/cocoindex/functions.py | 1 + .../functions/_engine_builtin_specs.py | 1 + python/cocoindex/llm.py | 57 +++++++++++++++- src/llm/anthropic.rs | 17 +++-- src/llm/gemini.rs | 17 +++-- src/llm/litellm.rs | 10 ++- src/llm/mod.rs | 68 ++++++++++++++----- src/llm/openai.rs | 13 ++-- src/llm/openrouter.rs | 10 ++- src/llm/vllm.rs | 10 ++- src/llm/voyage.rs | 18 +++-- src/ops/functions/embed_text.rs | 55 ++++++++++++++- 12 files changed, 234 insertions(+), 43 deletions(-) diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py index 13765907..64955d4f 100644 --- a/python/cocoindex/functions.py +++ b/python/cocoindex/functions.py @@ -67,6 +67,7 @@ class EmbedText(op.FunctionSpec): output_dimension: int | None = None task_type: str | None = None api_config: llm.VertexAiConfig | None = None + api_key: str | None = None class ExtractByLlm(op.FunctionSpec): diff --git a/python/cocoindex/functions/_engine_builtin_specs.py b/python/cocoindex/functions/_engine_builtin_specs.py index 29353f1e..4d046419 100644 --- a/python/cocoindex/functions/_engine_builtin_specs.py +++ b/python/cocoindex/functions/_engine_builtin_specs.py @@ -52,6 +52,7 @@ class EmbedText(op.FunctionSpec): output_dimension: int | None = None task_type: str | None = None api_config: llm.VertexAiConfig | None = None + api_key: str | None = None class ExtractByLlm(op.FunctionSpec): diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 4dc98c42..0a1c42fa 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -34,6 +34,61 @@ class OpenAiConfig: org_id: str | None = None project_id: str | None = None + api_key: str | None = None + + +@dataclass +class AnthropicConfig: + """A specification for an Anthropic LLM.""" + + kind = "Anthropic" + + api_key: str | None = None + + +@dataclass +class GeminiConfig: + """A specification for a Gemini LLM.""" + + kind = "Gemini" + + api_key: str | None = None + + +@dataclass +class VoyageConfig: + """A specification for a Voyage LLM.""" + + kind = "Voyage" + + api_key: str | None = None + + +@dataclass +class LiteLlmConfig: + """A specification for a LiteLLM LLM.""" + + kind = "LiteLlm" + + api_key: str | None = None + + +@dataclass +class OpenRouterConfig: + """A specification for an OpenRouter LLM.""" + + kind = "OpenRouter" + + api_key: str | None = None + + +@dataclass +class VllmConfig: + """A specification for a VLLM LLM.""" + + kind = "Vllm" + + api_key: str | None = None @dataclass @@ -43,4 +98,4 @@ class LlmSpec: api_type: LlmApiType model: str address: str | None = None - api_config: VertexAiConfig | OpenAiConfig | None = None + api_config: VertexAiConfig | OpenAiConfig | AnthropicConfig | GeminiConfig | VoyageConfig | LiteLlmConfig | OpenRouterConfig | VllmConfig | None = None diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index d81f5d76..36dbbf61 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -14,14 +14,23 @@ pub struct Client { } impl Client { - pub async fn new(address: Option) -> Result { + pub async fn new(address: Option, api_config: Option) -> Result { if address.is_some() { api_bail!("Anthropic doesn't support custom API address"); } - let api_key = match std::env::var("ANTHROPIC_API_KEY") { - Ok(val) => val, - Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"), + + let api_key = if let Some(super::LlmApiConfig::Anthropic(config)) = api_config { + if let Some(key) = config.api_key { + key + } else { + std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set"))? + } + } else { + std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set"))? }; + Ok(Self { api_key, client: reqwest::Client::new(), diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index bb2aeb0f..f1c7eb3a 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -30,14 +30,23 @@ pub struct AiStudioClient { } impl AiStudioClient { - pub fn new(address: Option) -> Result { + pub fn new(address: Option, api_config: Option) -> Result { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); } - let api_key = match std::env::var("GEMINI_API_KEY") { - Ok(val) => val, - Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"), + + let api_key = if let Some(super::LlmApiConfig::Gemini(config)) = api_config { + if let Some(key) = config.api_key { + key + } else { + std::env::var("GEMINI_API_KEY") + .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))? + } + } else { + std::env::var("GEMINI_API_KEY") + .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))? }; + Ok(Self { api_key, client: reqwest::Client::new(), diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index 85d1b50e..cbca0bc7 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -4,9 +4,15 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_litellm(address: Option) -> anyhow::Result { + pub async fn new_litellm(address: Option, api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); - let api_key = std::env::var("LITELLM_API_KEY").ok(); + + let api_key = if let Some(super::LlmApiConfig::LiteLlm(config)) = api_config { + config.api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok()) + } else { + std::env::var("LITELLM_API_KEY").ok() + }; + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 2145f2d1..75497de2 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -30,6 +30,37 @@ pub struct VertexAiConfig { pub struct OpenAiConfig { pub org_id: Option, pub project_id: Option, + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnthropicConfig { + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeminiConfig { + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoyageConfig { + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LiteLlmConfig { + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenRouterConfig { + pub api_key: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VllmConfig { + pub api_key: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -37,6 +68,12 @@ pub struct OpenAiConfig { pub enum LlmApiConfig { VertexAi(VertexAiConfig), OpenAi(OpenAiConfig), + Anthropic(AnthropicConfig), + Gemini(GeminiConfig), + Voyage(VoyageConfig), + LiteLlm(LiteLlmConfig), + OpenRouter(OpenRouterConfig), + Vllm(VllmConfig), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -126,25 +163,23 @@ pub async fn new_llm_generation_client( LlmApiType::OpenAi => { Box::new(openai::Client::new(address, api_config)?) as Box } - LlmApiType::Gemini => { - Box::new(gemini::AiStudioClient::new(address)?) as Box - } + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_config)?) + as Box, LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) as Box, - LlmApiType::Anthropic => { - Box::new(anthropic::Client::new(address).await?) as Box - } - LlmApiType::LiteLlm => { - Box::new(litellm::Client::new_litellm(address).await?) as Box - } - LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?) + LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_config).await?) + as Box, + LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_config).await?) as Box, + LlmApiType::OpenRouter => { + Box::new(openrouter::Client::new_openrouter(address, api_config).await?) + as Box + } LlmApiType::Voyage => { api_bail!("Voyage is not supported for generation") } - LlmApiType::Vllm => { - Box::new(vllm::Client::new_vllm(address).await?) as Box - } + LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_config).await?) + as Box, }; Ok(client) } @@ -158,14 +193,13 @@ pub async fn new_llm_embedding_client( LlmApiType::Ollama => { Box::new(ollama::Client::new(address).await?) as Box } - LlmApiType::Gemini => { - Box::new(gemini::AiStudioClient::new(address)?) as Box - } + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_config)?) + as Box, LlmApiType::OpenAi => { Box::new(openai::Client::new(address, api_config)?) as Box } LlmApiType::Voyage => { - Box::new(voyage::Client::new(address)?) as Box + Box::new(voyage::Client::new(address, api_config)?) as Box } LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) as Box, diff --git a/src/llm/openai.rs b/src/llm/openai.rs index a29a9bce..f4715875 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -50,13 +50,16 @@ impl Client { if let Some(project_id) = config.project_id { openai_config = openai_config.with_project_id(project_id); } - - // Verify API key is set - if std::env::var("OPENAI_API_KEY").is_err() { - api_bail!("OPENAI_API_KEY environment variable must be set"); + if let Some(api_key) = config.api_key { + openai_config = openai_config.with_api_key(api_key); + } else { + // Verify API key is set in environment if not provided in config + if std::env::var("OPENAI_API_KEY").is_err() { + api_bail!("OPENAI_API_KEY environment variable must be set"); + } } + Ok(Self { - // OpenAI client will use OPENAI_API_KEY and OPENAI_API_BASE env variables by default client: OpenAIClient::with_config(openai_config), }) } diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index ecf4d0fa..016688e5 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -4,9 +4,15 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_openrouter(address: Option) -> anyhow::Result { + pub async fn new_openrouter(address: Option, api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); - let api_key = std::env::var("OPENROUTER_API_KEY").ok(); + + let api_key = if let Some(super::LlmApiConfig::OpenRouter(config)) = api_config { + config.api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok()) + } else { + std::env::var("OPENROUTER_API_KEY").ok() + }; + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs index 1f32bc65..2133919b 100644 --- a/src/llm/vllm.rs +++ b/src/llm/vllm.rs @@ -4,9 +4,15 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_vllm(address: Option) -> anyhow::Result { + pub async fn new_vllm(address: Option, api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string()); - let api_key = std::env::var("VLLM_API_KEY").ok(); + + let api_key = if let Some(super::LlmApiConfig::Vllm(config)) = api_config { + config.api_key.or_else(|| std::env::var("VLLM_API_KEY").ok()) + } else { + std::env::var("VLLM_API_KEY").ok() + }; + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs index ea20e7d2..ac7d689e 100644 --- a/src/llm/voyage.rs +++ b/src/llm/voyage.rs @@ -33,14 +33,24 @@ pub struct Client { } impl Client { - pub fn new(address: Option) -> Result { + pub fn new(address: Option, api_config: Option) -> Result { if address.is_some() { api_bail!("Voyage AI doesn't support custom API address"); } - let api_key = match std::env::var("VOYAGE_API_KEY") { - Ok(val) => val, - Err(_) => api_bail!("VOYAGE_API_KEY environment variable must be set"), + + let api_key = if let Some(super::LlmApiConfig::Voyage(config)) = api_config { + if let Some(key) = config.api_key { + key + } else { + std::env::var("VOYAGE_API_KEY").map_err(|_| { + anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set") + })? + } + } else { + std::env::var("VOYAGE_API_KEY") + .map_err(|_| anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set"))? }; + Ok(Self { api_key, client: reqwest::Client::new(), diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index bd870158..b114f282 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -13,6 +13,7 @@ struct Spec { api_config: Option, output_dimension: Option, task_type: Option, + api_key: Option, } struct Args { @@ -91,9 +92,58 @@ impl SimpleFunctionFactoryBase for Factory { .next_arg("text")? .expect_type(&ValueType::Basic(BasicValueType::Str))? .required()?; + + // Create API config based on api_key parameter if provided + let api_config = if let Some(api_key) = &spec.api_key { + Some(match spec.api_type { + LlmApiType::OpenAi => { + LlmApiConfig::OpenAi(super::super::super::llm::OpenAiConfig { + org_id: None, + project_id: None, + api_key: Some(api_key.clone()), + }) + } + LlmApiType::Anthropic => { + LlmApiConfig::Anthropic(super::super::super::llm::AnthropicConfig { + api_key: Some(api_key.clone()), + }) + } + LlmApiType::Gemini => { + LlmApiConfig::Gemini(super::super::super::llm::GeminiConfig { + api_key: Some(api_key.clone()), + }) + } + LlmApiType::Voyage => { + LlmApiConfig::Voyage(super::super::super::llm::VoyageConfig { + api_key: Some(api_key.clone()), + }) + } + LlmApiType::LiteLlm => { + LlmApiConfig::LiteLlm(super::super::super::llm::LiteLlmConfig { + api_key: Some(api_key.clone()), + }) + } + LlmApiType::OpenRouter => { + LlmApiConfig::OpenRouter(super::super::super::llm::OpenRouterConfig { + api_key: Some(api_key.clone()), + }) + } + LlmApiType::Vllm => LlmApiConfig::Vllm(super::super::super::llm::VllmConfig { + api_key: Some(api_key.clone()), + }), + _ => spec.api_config.clone().unwrap_or_else(|| { + api_bail!( + "API key parameter is not supported for API type {:?}", + spec.api_type + ) + }), + }) + } else { + spec.api_config.clone() + }; + let client = - new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone()) - .await?; + new_llm_embedding_client(spec.api_type, spec.address.clone(), api_config).await?; let output_dimension = match spec.output_dimension { Some(output_dimension) => output_dimension, None => { @@ -144,6 +194,7 @@ mod tests { api_config: None, output_dimension: None, task_type: None, + api_key: None, }; let factory = Arc::new(Factory); From 0059fd89592c6dba9717a517514a682dd7dc963c Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sun, 5 Oct 2025 12:06:03 +0530 Subject: [PATCH 03/10] Moved api_key to Common LlmSpec Layer --- python/cocoindex/llm.py | 14 +----- src/llm/anthropic.rs | 24 +++++----- src/llm/gemini.rs | 12 ++--- src/llm/litellm.rs | 8 +--- src/llm/mod.rs | 71 ++++++++++++++--------------- src/llm/openai.rs | 6 +-- src/llm/openrouter.rs | 8 +--- src/llm/vllm.rs | 8 +--- src/llm/voyage.rs | 12 ++--- src/ops/functions/embed_text.rs | 27 +++-------- src/ops/functions/extract_by_llm.rs | 3 ++ 11 files changed, 72 insertions(+), 121 deletions(-) diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 0a1c42fa..93936e16 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -34,7 +34,6 @@ class OpenAiConfig: org_id: str | None = None project_id: str | None = None - api_key: str | None = None @dataclass @@ -43,8 +42,6 @@ class AnthropicConfig: kind = "Anthropic" - api_key: str | None = None - @dataclass class GeminiConfig: @@ -52,8 +49,6 @@ class GeminiConfig: kind = "Gemini" - api_key: str | None = None - @dataclass class VoyageConfig: @@ -61,8 +56,6 @@ class VoyageConfig: kind = "Voyage" - api_key: str | None = None - @dataclass class LiteLlmConfig: @@ -70,8 +63,6 @@ class LiteLlmConfig: kind = "LiteLlm" - api_key: str | None = None - @dataclass class OpenRouterConfig: @@ -79,8 +70,6 @@ class OpenRouterConfig: kind = "OpenRouter" - api_key: str | None = None - @dataclass class VllmConfig: @@ -88,8 +77,6 @@ class VllmConfig: kind = "Vllm" - api_key: str | None = None - @dataclass class LlmSpec: @@ -98,4 +85,5 @@ class LlmSpec: api_type: LlmApiType model: str address: str | None = None + api_key: str | None = None api_config: VertexAiConfig | OpenAiConfig | AnthropicConfig | GeminiConfig | VoyageConfig | LiteLlmConfig | OpenRouterConfig | VllmConfig | None = None diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index 36dbbf61..c252ab98 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -14,23 +14,23 @@ pub struct Client { } impl Client { - pub async fn new(address: Option, api_config: Option) -> Result { + pub async fn new( + address: Option, + api_key: Option, + _api_config: Option, + ) -> Result { if address.is_some() { api_bail!("Anthropic doesn't support custom API address"); } - - let api_key = if let Some(super::LlmApiConfig::Anthropic(config)) = api_config { - if let Some(key) = config.api_key { - key - } else { - std::env::var("ANTHROPIC_API_KEY") - .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set"))? - } + + let api_key = if let Some(key) = api_key { + key } else { - std::env::var("ANTHROPIC_API_KEY") - .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set"))? + std::env::var("ANTHROPIC_API_KEY").map_err(|_| { + anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set") + })? }; - + Ok(Self { api_key, client: reqwest::Client::new(), diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index f1c7eb3a..9de8bebb 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -30,18 +30,13 @@ pub struct AiStudioClient { } impl AiStudioClient { - pub fn new(address: Option, api_config: Option) -> Result { + pub fn new(address: Option, api_key: Option, _api_config: Option) -> Result { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); } - let api_key = if let Some(super::LlmApiConfig::Gemini(config)) = api_config { - if let Some(key) = config.api_key { - key - } else { - std::env::var("GEMINI_API_KEY") - .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))? - } + let api_key = if let Some(key) = api_key { + key } else { std::env::var("GEMINI_API_KEY") .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))? @@ -249,6 +244,7 @@ pub struct VertexAiClient { impl VertexAiClient { pub async fn new( address: Option, + _api_key: Option, api_config: Option, ) -> Result { if address.is_some() { diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index cbca0bc7..65869d8a 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -4,14 +4,10 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_litellm(address: Option, api_config: Option) -> anyhow::Result { + pub async fn new_litellm(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); - let api_key = if let Some(super::LlmApiConfig::LiteLlm(config)) = api_config { - config.api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok()) - } else { - std::env::var("LITELLM_API_KEY").ok() - }; + let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok()); let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 75497de2..f72dc8eb 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -30,38 +30,25 @@ pub struct VertexAiConfig { pub struct OpenAiConfig { pub org_id: Option, pub project_id: Option, - pub api_key: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AnthropicConfig { - pub api_key: Option, -} +pub struct AnthropicConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeminiConfig { - pub api_key: Option, -} +pub struct GeminiConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VoyageConfig { - pub api_key: Option, -} +pub struct VoyageConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LiteLlmConfig { - pub api_key: Option, -} +pub struct LiteLlmConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenRouterConfig { - pub api_key: Option, -} +pub struct OpenRouterConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VllmConfig { - pub api_key: Option, -} +pub struct VllmConfig {} #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind")] @@ -81,6 +68,7 @@ pub struct LlmSpec { pub api_type: LlmApiType, pub address: Option, pub model: String, + pub api_key: Option, pub api_config: Option, } @@ -154,31 +142,37 @@ mod voyage; pub async fn new_llm_generation_client( api_type: LlmApiType, address: Option, + api_key: Option, api_config: Option, ) -> Result> { let client = match api_type { LlmApiType::Ollama => { Box::new(ollama::Client::new(address).await?) as Box } - LlmApiType::OpenAi => { - Box::new(openai::Client::new(address, api_config)?) as Box - } - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_config)?) - as Box, - LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) + LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) as Box, - LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_config).await?) - as Box, - LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_config).await?) + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key, api_config)?) as Box, + LlmApiType::VertexAi => { + Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) + as Box + } + LlmApiType::Anthropic => { + Box::new(anthropic::Client::new(address, api_key, api_config).await?) + as Box + } + LlmApiType::LiteLlm => { + Box::new(litellm::Client::new_litellm(address, api_key, api_config).await?) + as Box + } LlmApiType::OpenRouter => { - Box::new(openrouter::Client::new_openrouter(address, api_config).await?) + Box::new(openrouter::Client::new_openrouter(address, api_key, api_config).await?) as Box } LlmApiType::Voyage => { api_bail!("Voyage is not supported for generation") } - LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_config).await?) + LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key, api_config).await?) as Box, }; Ok(client) @@ -187,22 +181,23 @@ pub async fn new_llm_generation_client( pub async fn new_llm_embedding_client( api_type: LlmApiType, address: Option, + api_key: Option, api_config: Option, ) -> Result> { let client = match api_type { LlmApiType::Ollama => { Box::new(ollama::Client::new(address).await?) as Box } - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_config)?) + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key, api_config)?) as Box, - LlmApiType::OpenAi => { - Box::new(openai::Client::new(address, api_config)?) as Box - } - LlmApiType::Voyage => { - Box::new(voyage::Client::new(address, api_config)?) as Box - } - LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) + LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) + as Box, + LlmApiType::Voyage => Box::new(voyage::Client::new(address, api_key, api_config)?) as Box, + LlmApiType::VertexAi => { + Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) + as Box + } LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => { api_bail!("Embedding is not supported for API type {:?}", api_type) } diff --git a/src/llm/openai.rs b/src/llm/openai.rs index f4715875..0ab3cd75 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -33,7 +33,7 @@ impl Client { Self { client } } - pub fn new(address: Option, api_config: Option) -> Result { + pub fn new(address: Option, api_key: Option, api_config: Option) -> Result { let config = match api_config { Some(super::LlmApiConfig::OpenAi(config)) => config, Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"), @@ -50,8 +50,8 @@ impl Client { if let Some(project_id) = config.project_id { openai_config = openai_config.with_project_id(project_id); } - if let Some(api_key) = config.api_key { - openai_config = openai_config.with_api_key(api_key); + if let Some(key) = api_key { + openai_config = openai_config.with_api_key(key); } else { // Verify API key is set in environment if not provided in config if std::env::var("OPENAI_API_KEY").is_err() { diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index 016688e5..db4aad55 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -4,14 +4,10 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_openrouter(address: Option, api_config: Option) -> anyhow::Result { + pub async fn new_openrouter(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); - let api_key = if let Some(super::LlmApiConfig::OpenRouter(config)) = api_config { - config.api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok()) - } else { - std::env::var("OPENROUTER_API_KEY").ok() - }; + let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok()); let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs index 2133919b..d887c931 100644 --- a/src/llm/vllm.rs +++ b/src/llm/vllm.rs @@ -4,14 +4,10 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_vllm(address: Option, api_config: Option) -> anyhow::Result { + pub async fn new_vllm(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string()); - let api_key = if let Some(super::LlmApiConfig::Vllm(config)) = api_config { - config.api_key.or_else(|| std::env::var("VLLM_API_KEY").ok()) - } else { - std::env::var("VLLM_API_KEY").ok() - }; + let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok()); let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs index ac7d689e..2837526d 100644 --- a/src/llm/voyage.rs +++ b/src/llm/voyage.rs @@ -33,19 +33,13 @@ pub struct Client { } impl Client { - pub fn new(address: Option, api_config: Option) -> Result { + pub fn new(address: Option, api_key: Option, _api_config: Option) -> Result { if address.is_some() { api_bail!("Voyage AI doesn't support custom API address"); } - let api_key = if let Some(super::LlmApiConfig::Voyage(config)) = api_config { - if let Some(key) = config.api_key { - key - } else { - std::env::var("VOYAGE_API_KEY").map_err(|_| { - anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set") - })? - } + let api_key = if let Some(key) = api_key { + key } else { std::env::var("VOYAGE_API_KEY") .map_err(|_| anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set"))? diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index b114f282..c84d8c46 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -100,37 +100,24 @@ impl SimpleFunctionFactoryBase for Factory { LlmApiConfig::OpenAi(super::super::super::llm::OpenAiConfig { org_id: None, project_id: None, - api_key: Some(api_key.clone()), }) } LlmApiType::Anthropic => { - LlmApiConfig::Anthropic(super::super::super::llm::AnthropicConfig { - api_key: Some(api_key.clone()), - }) + LlmApiConfig::Anthropic(super::super::super::llm::AnthropicConfig {}) } LlmApiType::Gemini => { - LlmApiConfig::Gemini(super::super::super::llm::GeminiConfig { - api_key: Some(api_key.clone()), - }) + LlmApiConfig::Gemini(super::super::super::llm::GeminiConfig {}) } LlmApiType::Voyage => { - LlmApiConfig::Voyage(super::super::super::llm::VoyageConfig { - api_key: Some(api_key.clone()), - }) + LlmApiConfig::Voyage(super::super::super::llm::VoyageConfig {}) } LlmApiType::LiteLlm => { - LlmApiConfig::LiteLlm(super::super::super::llm::LiteLlmConfig { - api_key: Some(api_key.clone()), - }) + LlmApiConfig::LiteLlm(super::super::super::llm::LiteLlmConfig {}) } LlmApiType::OpenRouter => { - LlmApiConfig::OpenRouter(super::super::super::llm::OpenRouterConfig { - api_key: Some(api_key.clone()), - }) + LlmApiConfig::OpenRouter(super::super::super::llm::OpenRouterConfig {}) } - LlmApiType::Vllm => LlmApiConfig::Vllm(super::super::super::llm::VllmConfig { - api_key: Some(api_key.clone()), - }), + LlmApiType::Vllm => LlmApiConfig::Vllm(super::super::super::llm::VllmConfig {}), _ => spec.api_config.clone().unwrap_or_else(|| { api_bail!( "API key parameter is not supported for API type {:?}", @@ -143,7 +130,7 @@ impl SimpleFunctionFactoryBase for Factory { }; let client = - new_llm_embedding_client(spec.api_type, spec.address.clone(), api_config).await?; + new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_key.clone(), api_config).await?; let output_dimension = match spec.output_dimension { Some(output_dimension) => output_dimension, None => { diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index 4dfe9d4d..d929f9ae 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -55,6 +55,7 @@ impl Executor { let client = new_llm_generation_client( spec.llm_spec.api_type, spec.llm_spec.address, + spec.llm_spec.api_key, spec.llm_spec.api_config, ) .await?; @@ -204,6 +205,7 @@ mod tests { api_type: crate::llm::LlmApiType::OpenAi, model: "gpt-4o".to_string(), address: None, + api_key: None, api_config: None, }, output_type: output_type_spec, @@ -274,6 +276,7 @@ mod tests { api_type: crate::llm::LlmApiType::OpenAi, model: "gpt-4o".to_string(), address: None, + api_key: None, api_config: None, }, output_type: make_output_type(BasicValueType::Str), From f7b4dcb8438cd3ec85937c7c1d88697ede9ccae7 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sun, 5 Oct 2025 12:10:11 +0530 Subject: [PATCH 04/10] Fix formatting issues for GitHub Actions - Fixed Python line length issue in llm.py by breaking long type annotation - Fixed Rust function signature formatting in all LLM client files - Fixed long function call formatting in embed_text.rs - All formatting now complies with project standards --- python/cocoindex/llm.py | 12 +++++++++++- src/llm/gemini.rs | 6 +++++- src/llm/litellm.rs | 6 +++++- src/llm/openai.rs | 6 +++++- src/llm/openrouter.rs | 6 +++++- src/llm/vllm.rs | 6 +++++- src/llm/voyage.rs | 6 +++++- src/ops/functions/embed_text.rs | 9 +++++++-- 8 files changed, 48 insertions(+), 9 deletions(-) diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 93936e16..e6b12ef9 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -86,4 +86,14 @@ class LlmSpec: model: str address: str | None = None api_key: str | None = None - api_config: VertexAiConfig | OpenAiConfig | AnthropicConfig | GeminiConfig | VoyageConfig | LiteLlmConfig | OpenRouterConfig | VllmConfig | None = None + api_config: ( + VertexAiConfig + | OpenAiConfig + | AnthropicConfig + | GeminiConfig + | VoyageConfig + | LiteLlmConfig + | OpenRouterConfig + | VllmConfig + | None + ) = None diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index 9de8bebb..d1f3f593 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -30,7 +30,11 @@ pub struct AiStudioClient { } impl AiStudioClient { - pub fn new(address: Option, api_key: Option, _api_config: Option) -> Result { + pub fn new( + address: Option, + api_key: Option, + _api_config: Option, + ) -> Result { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); } diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index 65869d8a..31b51e5c 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -4,7 +4,11 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_litellm(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { + pub async fn new_litellm( + address: Option, + api_key: Option, + _api_config: Option, + ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok()); diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 0ab3cd75..ac5c4113 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -33,7 +33,11 @@ impl Client { Self { client } } - pub fn new(address: Option, api_key: Option, api_config: Option) -> Result { + pub fn new( + address: Option, + api_key: Option, + api_config: Option, + ) -> Result { let config = match api_config { Some(super::LlmApiConfig::OpenAi(config)) => config, Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"), diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index db4aad55..350174e4 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -4,7 +4,11 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_openrouter(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { + pub async fn new_openrouter( + address: Option, + api_key: Option, + _api_config: Option, + ) -> anyhow::Result { let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok()); diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs index d887c931..67518f22 100644 --- a/src/llm/vllm.rs +++ b/src/llm/vllm.rs @@ -4,7 +4,11 @@ use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { - pub async fn new_vllm(address: Option, api_key: Option, _api_config: Option) -> anyhow::Result { + pub async fn new_vllm( + address: Option, + api_key: Option, + _api_config: Option, + ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string()); let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok()); diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs index 2837526d..75aa4ac8 100644 --- a/src/llm/voyage.rs +++ b/src/llm/voyage.rs @@ -33,7 +33,11 @@ pub struct Client { } impl Client { - pub fn new(address: Option, api_key: Option, _api_config: Option) -> Result { + pub fn new( + address: Option, + api_key: Option, + _api_config: Option, + ) -> Result { if address.is_some() { api_bail!("Voyage AI doesn't support custom API address"); } diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index c84d8c46..d518fdd6 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -129,8 +129,13 @@ impl SimpleFunctionFactoryBase for Factory { spec.api_config.clone() }; - let client = - new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_key.clone(), api_config).await?; + let client = new_llm_embedding_client( + spec.api_type, + spec.address.clone(), + spec.api_key.clone(), + api_config, + ) + .await?; let output_dimension = match spec.output_dimension { Some(output_dimension) => output_dimension, None => { From f48b50392162c5e8b7e8b53701138109ddaa6275 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sun, 5 Oct 2025 12:10:53 +0530 Subject: [PATCH 05/10] Fix type mismatch error in embed_text.rs - Fixed api_bail! usage in context expecting LlmApiConfig return type - Replaced unwrap_or_else with proper if-let pattern matching - Resolves compilation error in GitHub Actions build test --- src/ops/functions/embed_text.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index d518fdd6..4f4088d7 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -118,12 +118,16 @@ impl SimpleFunctionFactoryBase for Factory { LlmApiConfig::OpenRouter(super::super::super::llm::OpenRouterConfig {}) } LlmApiType::Vllm => LlmApiConfig::Vllm(super::super::super::llm::VllmConfig {}), - _ => spec.api_config.clone().unwrap_or_else(|| { - api_bail!( - "API key parameter is not supported for API type {:?}", - spec.api_type - ) - }), + _ => { + if let Some(config) = spec.api_config.clone() { + config + } else { + api_bail!( + "API key parameter is not supported for API type {:?}", + spec.api_type + ) + } + }, }) } else { spec.api_config.clone() From 059ae69336c3bf39e108beafe8cd9816bf522a57 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Sun, 5 Oct 2025 12:12:50 +0530 Subject: [PATCH 06/10] Fix trailing whitespace formatting issues - Removed trailing whitespace from all LLM client files - Fixed formatting issues in gemini.rs, litellm.rs, openai.rs, openrouter.rs, vllm.rs - Fixed trailing whitespace in embed_text.rs - All files now comply with cargo fmt standards --- src/llm/gemini.rs | 4 ++-- src/llm/litellm.rs | 4 ++-- src/llm/openai.rs | 2 +- src/llm/openrouter.rs | 4 ++-- src/llm/vllm.rs | 4 ++-- src/ops/functions/embed_text.rs | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index d1f3f593..dde11581 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -38,14 +38,14 @@ impl AiStudioClient { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); } - + let api_key = if let Some(key) = api_key { key } else { std::env::var("GEMINI_API_KEY") .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))? }; - + Ok(Self { api_key, client: reqwest::Client::new(), diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index 31b51e5c..96fb3761 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -10,9 +10,9 @@ impl Client { _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); - + let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok()); - + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/llm/openai.rs b/src/llm/openai.rs index ac5c4113..ea76040c 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -62,7 +62,7 @@ impl Client { api_bail!("OPENAI_API_KEY environment variable must be set"); } } - + Ok(Self { client: OpenAIClient::with_config(openai_config), }) diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index 350174e4..1c7b5655 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -10,9 +10,9 @@ impl Client { _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); - + let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok()); - + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs index 67518f22..3167e582 100644 --- a/src/llm/vllm.rs +++ b/src/llm/vllm.rs @@ -10,9 +10,9 @@ impl Client { _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string()); - + let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok()); - + let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 4f4088d7..1028f69d 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -127,7 +127,7 @@ impl SimpleFunctionFactoryBase for Factory { spec.api_type ) } - }, + } }) } else { spec.api_config.clone() From 26b57889678261b83069138e82303b7a5a72c1a9 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Thu, 9 Oct 2025 10:24:58 +0530 Subject: [PATCH 07/10] Feature: Enable programmatically pass in api_key besides reading from env --- python/cocoindex/llm.py | 48 ----------------------------------------- src/llm/mod.rs | 24 --------------------- 2 files changed, 72 deletions(-) diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index e6b12ef9..3cf328d2 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -36,48 +36,6 @@ class OpenAiConfig: project_id: str | None = None -@dataclass -class AnthropicConfig: - """A specification for an Anthropic LLM.""" - - kind = "Anthropic" - - -@dataclass -class GeminiConfig: - """A specification for a Gemini LLM.""" - - kind = "Gemini" - - -@dataclass -class VoyageConfig: - """A specification for a Voyage LLM.""" - - kind = "Voyage" - - -@dataclass -class LiteLlmConfig: - """A specification for a LiteLLM LLM.""" - - kind = "LiteLlm" - - -@dataclass -class OpenRouterConfig: - """A specification for an OpenRouter LLM.""" - - kind = "OpenRouter" - - -@dataclass -class VllmConfig: - """A specification for a VLLM LLM.""" - - kind = "Vllm" - - @dataclass class LlmSpec: """A specification for a LLM.""" @@ -89,11 +47,5 @@ class LlmSpec: api_config: ( VertexAiConfig | OpenAiConfig - | AnthropicConfig - | GeminiConfig - | VoyageConfig - | LiteLlmConfig - | OpenRouterConfig - | VllmConfig | None ) = None diff --git a/src/llm/mod.rs b/src/llm/mod.rs index f72dc8eb..d7dde55a 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -32,35 +32,11 @@ pub struct OpenAiConfig { pub project_id: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AnthropicConfig {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeminiConfig {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VoyageConfig {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LiteLlmConfig {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenRouterConfig {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VllmConfig {} - #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind")] pub enum LlmApiConfig { VertexAi(VertexAiConfig), OpenAi(OpenAiConfig), - Anthropic(AnthropicConfig), - Gemini(GeminiConfig), - Voyage(VoyageConfig), - LiteLlm(LiteLlmConfig), - OpenRouter(OpenRouterConfig), - Vllm(VllmConfig), } #[derive(Debug, Clone, Serialize, Deserialize)] From ccb1b0fa99f707d89f6e7bfa59f3affe6784ab62 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Thu, 9 Oct 2025 10:31:49 +0530 Subject: [PATCH 08/10] Feature: Enable programmatically pass in api_key besides reading from env --- python/cocoindex/llm.py | 6 +---- src/ops/functions/embed_text.rs | 39 ++++++++++++++------------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 3cf328d2..3f7eefe7 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -44,8 +44,4 @@ class LlmSpec: model: str address: str | None = None api_key: str | None = None - api_config: ( - VertexAiConfig - | OpenAiConfig - | None - ) = None + api_config: VertexAiConfig | OpenAiConfig | None = None diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 1028f69d..dd2f5714 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -94,33 +94,26 @@ impl SimpleFunctionFactoryBase for Factory { .required()?; // Create API config based on api_key parameter if provided - let api_config = if let Some(api_key) = &spec.api_key { - Some(match spec.api_type { - LlmApiType::OpenAi => { - LlmApiConfig::OpenAi(super::super::super::llm::OpenAiConfig { + let api_config = if let Some(_api_key) = &spec.api_key { + match spec.api_type { + LlmApiType::OpenAi => Some(LlmApiConfig::OpenAi( + super::super::super::llm::OpenAiConfig { org_id: None, project_id: None, - }) + }, + )), + LlmApiType::Anthropic + | LlmApiType::Gemini + | LlmApiType::Voyage + | LlmApiType::LiteLlm + | LlmApiType::OpenRouter + | LlmApiType::Vllm => { + // These API types don't require a config, just an API key + None } - LlmApiType::Anthropic => { - LlmApiConfig::Anthropic(super::super::super::llm::AnthropicConfig {}) - } - LlmApiType::Gemini => { - LlmApiConfig::Gemini(super::super::super::llm::GeminiConfig {}) - } - LlmApiType::Voyage => { - LlmApiConfig::Voyage(super::super::super::llm::VoyageConfig {}) - } - LlmApiType::LiteLlm => { - LlmApiConfig::LiteLlm(super::super::super::llm::LiteLlmConfig {}) - } - LlmApiType::OpenRouter => { - LlmApiConfig::OpenRouter(super::super::super::llm::OpenRouterConfig {}) - } - LlmApiType::Vllm => LlmApiConfig::Vllm(super::super::super::llm::VllmConfig {}), _ => { if let Some(config) = spec.api_config.clone() { - config + Some(config) } else { api_bail!( "API key parameter is not supported for API type {:?}", @@ -128,7 +121,7 @@ impl SimpleFunctionFactoryBase for Factory { ) } } - }) + } } else { spec.api_config.clone() }; From f1116158ff3ff6bfc75709cd391e8a091865401b Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Fri, 10 Oct 2025 11:52:20 +0530 Subject: [PATCH 09/10] Feature: Enable programmatically pass in api_key besides reading from env --- src/llm/anthropic.rs | 6 +----- src/llm/gemini.rs | 6 +----- src/llm/litellm.rs | 1 - src/llm/mod.rs | 14 ++++++------- src/llm/openrouter.rs | 1 - src/llm/vllm.rs | 1 - src/llm/voyage.rs | 6 +----- src/ops/functions/embed_text.rs | 35 +-------------------------------- 8 files changed, 11 insertions(+), 59 deletions(-) diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index c252ab98..1c4755ce 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -14,11 +14,7 @@ pub struct Client { } impl Client { - pub async fn new( - address: Option, - api_key: Option, - _api_config: Option, - ) -> Result { + pub async fn new(address: Option, api_key: Option) -> Result { if address.is_some() { api_bail!("Anthropic doesn't support custom API address"); } diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index 18ac0c4e..4d23c392 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -34,11 +34,7 @@ pub struct AiStudioClient { } impl AiStudioClient { - pub fn new( - address: Option, - api_key: Option, - _api_config: Option, - ) -> Result { + pub fn new(address: Option, api_key: Option) -> Result { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); } diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index 96fb3761..c2503dd7 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -7,7 +7,6 @@ impl Client { pub async fn new_litellm( address: Option, api_key: Option, - _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); diff --git a/src/llm/mod.rs b/src/llm/mod.rs index d7dde55a..ee10e484 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -127,28 +127,28 @@ pub async fn new_llm_generation_client( } LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) as Box, - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key, api_config)?) + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box, LlmApiType::VertexAi => { Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) as Box } LlmApiType::Anthropic => { - Box::new(anthropic::Client::new(address, api_key, api_config).await?) + Box::new(anthropic::Client::new(address, api_key).await?) as Box } LlmApiType::LiteLlm => { - Box::new(litellm::Client::new_litellm(address, api_key, api_config).await?) + Box::new(litellm::Client::new_litellm(address, api_key).await?) as Box } LlmApiType::OpenRouter => { - Box::new(openrouter::Client::new_openrouter(address, api_key, api_config).await?) + Box::new(openrouter::Client::new_openrouter(address, api_key).await?) as Box } LlmApiType::Voyage => { api_bail!("Voyage is not supported for generation") } - LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key, api_config).await?) + LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key).await?) as Box, }; Ok(client) @@ -164,11 +164,11 @@ pub async fn new_llm_embedding_client( LlmApiType::Ollama => { Box::new(ollama::Client::new(address).await?) as Box } - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key, api_config)?) + LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box, LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) as Box, - LlmApiType::Voyage => Box::new(voyage::Client::new(address, api_key, api_config)?) + LlmApiType::Voyage => Box::new(voyage::Client::new(address, api_key)?) as Box, LlmApiType::VertexAi => { Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index 1c7b5655..9298cdbc 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -7,7 +7,6 @@ impl Client { pub async fn new_openrouter( address: Option, api_key: Option, - _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs index 3167e582..c7528802 100644 --- a/src/llm/vllm.rs +++ b/src/llm/vllm.rs @@ -7,7 +7,6 @@ impl Client { pub async fn new_vllm( address: Option, api_key: Option, - _api_config: Option, ) -> anyhow::Result { let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string()); diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs index 75aa4ac8..dbff8af1 100644 --- a/src/llm/voyage.rs +++ b/src/llm/voyage.rs @@ -33,11 +33,7 @@ pub struct Client { } impl Client { - pub fn new( - address: Option, - api_key: Option, - _api_config: Option, - ) -> Result { + pub fn new(address: Option, api_key: Option) -> Result { if address.is_some() { api_bail!("Voyage AI doesn't support custom API address"); } diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index dd2f5714..825d98fa 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -93,44 +93,11 @@ impl SimpleFunctionFactoryBase for Factory { .expect_type(&ValueType::Basic(BasicValueType::Str))? .required()?; - // Create API config based on api_key parameter if provided - let api_config = if let Some(_api_key) = &spec.api_key { - match spec.api_type { - LlmApiType::OpenAi => Some(LlmApiConfig::OpenAi( - super::super::super::llm::OpenAiConfig { - org_id: None, - project_id: None, - }, - )), - LlmApiType::Anthropic - | LlmApiType::Gemini - | LlmApiType::Voyage - | LlmApiType::LiteLlm - | LlmApiType::OpenRouter - | LlmApiType::Vllm => { - // These API types don't require a config, just an API key - None - } - _ => { - if let Some(config) = spec.api_config.clone() { - Some(config) - } else { - api_bail!( - "API key parameter is not supported for API type {:?}", - spec.api_type - ) - } - } - } - } else { - spec.api_config.clone() - }; - let client = new_llm_embedding_client( spec.api_type, spec.address.clone(), spec.api_key.clone(), - api_config, + spec.api_config.clone(), ) .await?; let output_dimension = match spec.output_dimension { From 86269ba317b46564be040136d67a91fedecbed29 Mon Sep 17 00:00:00 2001 From: Arya Soni Date: Fri, 10 Oct 2025 11:57:06 +0530 Subject: [PATCH 10/10] Feature: Enable programmatically pass in api_key besides reading from env --- src/llm/mod.rs | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/llm/mod.rs b/src/llm/mod.rs index ee10e484..5c84980d 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -127,20 +127,17 @@ pub async fn new_llm_generation_client( } LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) as Box, - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key)?) - as Box, + LlmApiType::Gemini => { + Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box + } LlmApiType::VertexAi => { Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) as Box } - LlmApiType::Anthropic => { - Box::new(anthropic::Client::new(address, api_key).await?) - as Box - } - LlmApiType::LiteLlm => { - Box::new(litellm::Client::new_litellm(address, api_key).await?) - as Box - } + LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_key).await?) + as Box, + LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_key).await?) + as Box, LlmApiType::OpenRouter => { Box::new(openrouter::Client::new_openrouter(address, api_key).await?) as Box @@ -164,12 +161,14 @@ pub async fn new_llm_embedding_client( LlmApiType::Ollama => { Box::new(ollama::Client::new(address).await?) as Box } - LlmApiType::Gemini => Box::new(gemini::AiStudioClient::new(address, api_key)?) - as Box, + LlmApiType::Gemini => { + Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box + } LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?) as Box, - LlmApiType::Voyage => Box::new(voyage::Client::new(address, api_key)?) - as Box, + LlmApiType::Voyage => { + Box::new(voyage::Client::new(address, api_key)?) as Box + } LlmApiType::VertexAi => { Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?) as Box