diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py
index 13765907..64955d4f 100644
--- a/python/cocoindex/functions.py
+++ b/python/cocoindex/functions.py
@@ -67,6 +67,7 @@ class EmbedText(op.FunctionSpec):
     output_dimension: int | None = None
     task_type: str | None = None
     api_config: llm.VertexAiConfig | None = None
+    api_key: str | None = None
 
 
 class ExtractByLlm(op.FunctionSpec):
diff --git a/python/cocoindex/functions/_engine_builtin_specs.py b/python/cocoindex/functions/_engine_builtin_specs.py
index e3948c63..f7564164 100644
--- a/python/cocoindex/functions/_engine_builtin_specs.py
+++ b/python/cocoindex/functions/_engine_builtin_specs.py
@@ -56,6 +56,7 @@ class EmbedText(op.FunctionSpec):
     output_dimension: int | None = None
     task_type: str | None = None
     api_config: llm.VertexAiConfig | None = None
+    api_key: str | None = None
 
 
 class ExtractByLlm(op.FunctionSpec):
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index 3f12c90a..3ffe393b 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -44,4 +44,5 @@ class LlmSpec:
    api_type: LlmApiType
    model: str
    address: str | None = None
+   api_key: str | None = None
    api_config: VertexAiConfig | OpenAiConfig | None = None
diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs
index d81f5d76..1c4755ce 100644
--- a/src/llm/anthropic.rs
+++ b/src/llm/anthropic.rs
@@ -14,14 +14,19 @@ pub struct Client {
 }
 
 impl Client {
-    pub async fn new(address: Option<String>) -> Result<Self> {
+    pub async fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Anthropic doesn't support custom API address");
         }
-        let api_key = match std::env::var("ANTHROPIC_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
+                anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set")
+            })?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index 48461b4c..4d23c392 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -34,14 +34,18 @@ pub struct AiStudioClient {
 }
 
 impl AiStudioClient {
-    pub fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Gemini doesn't support custom API address");
         }
-        let api_key = match std::env::var("GEMINI_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("GEMINI_API_KEY")
+                .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
@@ -271,6 +275,7 @@ static SHARED_RETRY_THROTTLER: LazyLock<SharedRetryThrottler> =
 impl VertexAiClient {
     pub async fn new(
         address: Option<String>,
+        _api_key: Option<String>,
         api_config: Option<LlmApiConfig>,
     ) -> Result<Self> {
         if address.is_some() {
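For reference, a minimal sketch of what the new field enables on the Python side, assuming the existing `EmbedText` transform API; the model name and the `MY_GEMINI_KEY` variable are illustrative, not part of this change:

```python
import os

import cocoindex

# Hypothetical usage: pass the key explicitly instead of relying on the
# GEMINI_API_KEY environment variable (which remains the fallback when
# api_key is None).
embed = cocoindex.functions.EmbedText(
    api_type=cocoindex.LlmApiType.GEMINI,
    model="text-embedding-004",
    api_key=os.environ.get("MY_GEMINI_KEY"),
)
```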
diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs
index 85d1b50e..c2503dd7 100644
--- a/src/llm/litellm.rs
+++ b/src/llm/litellm.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_litellm(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
-        let api_key = std::env::var("LITELLM_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 12eda662..d4a27b04 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -45,6 +45,7 @@ pub struct LlmSpec {
     pub api_type: LlmApiType,
     pub address: Option<String>,
     pub model: String,
+    pub api_key: Option<String>,
     pub api_config: Option<LlmApiConfig>,
 }
 
@@ -119,37 +120,38 @@ mod voyage;
 pub async fn new_llm_generation_client(
     api_type: LlmApiType,
     address: Option<String>,
+    api_key: Option<String>,
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match api_type {
         LlmApiType::Ollama => {
             Box::new(ollama::Client::new(address).await?) as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmGenerationClient>
-        }
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
+            as Box<dyn LlmGenerationClient>,
         LlmApiType::Gemini => {
-            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
+            Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
-            as Box<dyn LlmGenerationClient>,
-        LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
+        LlmApiType::VertexAi => {
+            Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
+                as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_key).await?)
+            as Box<dyn LlmGenerationClient>,
         LlmApiType::Bedrock => {
             Box::new(bedrock::Client::new(address).await?) as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?)
+        LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_key).await?)
             as Box<dyn LlmGenerationClient>,
+        LlmApiType::OpenRouter => {
+            Box::new(openrouter::Client::new_openrouter(address, api_key).await?)
+                as Box<dyn LlmGenerationClient>
+        }
         LlmApiType::Voyage => {
             api_bail!("Voyage is not supported for generation")
         }
-        LlmApiType::Vllm => {
-            Box::new(vllm::Client::new_vllm(address).await?) as Box<dyn LlmGenerationClient>
-        }
+        LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key).await?)
+            as Box<dyn LlmGenerationClient>,
     };
     Ok(client)
 }
@@ -157,6 +159,7 @@ pub async fn new_llm_generation_client(
 pub async fn new_llm_embedding_client(
     api_type: LlmApiType,
     address: Option<String>,
+    api_key: Option<String>,
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmEmbeddingClient>> {
     let client = match api_type {
@@ -164,16 +167,17 @@ pub async fn new_llm_embedding_client(
             Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::Gemini => {
-            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
-        }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmEmbeddingClient>
+            Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
         }
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
+            as Box<dyn LlmEmbeddingClient>,
         LlmApiType::Voyage => {
-            Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
+            Box::new(voyage::Client::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
+        }
+        LlmApiType::VertexAi => {
+            Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
+                as Box<dyn LlmEmbeddingClient>
         }
-        LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
-            as Box<dyn LlmEmbeddingClient>,
         LlmApiType::OpenRouter
         | LlmApiType::LiteLlm
         | LlmApiType::Vllm
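The same plumbing reaches generation clients through `LlmSpec`. A sketch of the resulting Python-side usage, assuming the existing `ExtractByLlm` API; the dataclass, model name, and `MY_ANTHROPIC_KEY` variable are illustrative:

```python
import dataclasses
import os

import cocoindex

@dataclasses.dataclass
class Contact:
    name: str
    email: str

# Hypothetical usage: an explicit api_key on LlmSpec takes precedence;
# leaving it None keeps the old ANTHROPIC_API_KEY environment-variable
# behavior.
extract = cocoindex.functions.ExtractByLlm(
    llm_spec=cocoindex.LlmSpec(
        api_type=cocoindex.LlmApiType.ANTHROPIC,
        model="claude-3-5-sonnet-latest",
        api_key=os.environ.get("MY_ANTHROPIC_KEY"),
    ),
    output_type=Contact,
    instruction="Extract the contact mentioned in the text.",
)
```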
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
index 68ec6421..6a2c7f1d 100644
--- a/src/llm/openai.rs
+++ b/src/llm/openai.rs
@@ -32,7 +32,11 @@ impl Client {
         Self { client }
     }
 
-    pub fn new(address: Option<String>, api_config: Option<LlmApiConfig>) -> Result<Self> {
+    pub fn new(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<LlmApiConfig>,
+    ) -> Result<Self> {
         let config = match api_config {
             Some(super::LlmApiConfig::OpenAi(config)) => config,
             Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"),
@@ -49,13 +53,16 @@
         if let Some(project_id) = config.project_id {
             openai_config = openai_config.with_project_id(project_id);
         }
-
-        // Verify API key is set
-        if std::env::var("OPENAI_API_KEY").is_err() {
-            api_bail!("OPENAI_API_KEY environment variable must be set");
+        if let Some(key) = api_key {
+            openai_config = openai_config.with_api_key(key);
+        } else {
+            // Verify API key is set in environment if not provided in config
+            if std::env::var("OPENAI_API_KEY").is_err() {
+                api_bail!("OPENAI_API_KEY environment variable must be set");
+            }
         }
+
         Ok(Self {
-            // OpenAI client will use OPENAI_API_KEY and OPENAI_API_BASE env variables by default
             client: OpenAIClient::with_config(openai_config),
         })
     }
diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs
index ecf4d0fa..9298cdbc 100644
--- a/src/llm/openrouter.rs
+++ b/src/llm/openrouter.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_openrouter(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
-        let api_key = std::env::var("OPENROUTER_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
diff --git a/src/llm/vllm.rs b/src/llm/vllm.rs
index 1f32bc65..c7528802 100644
--- a/src/llm/vllm.rs
+++ b/src/llm/vllm.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_vllm(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_vllm(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string());
-        let api_key = std::env::var("VLLM_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
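For the OpenAI-compatible backends above (LiteLLM, OpenRouter, vLLM), the key stays optional either way. A sketch, assuming the Python `LlmApiType` enum exposes a `VLLM` member; the address and model are placeholders:

```python
import cocoindex

# Hypothetical usage: a local, unauthenticated vLLM server needs no key;
# an explicit api_key would take precedence over VLLM_API_KEY if set.
spec = cocoindex.LlmSpec(
    api_type=cocoindex.LlmApiType.VLLM,
    model="meta-llama/Llama-3.1-8B-Instruct",
    address="http://127.0.0.1:8000/v1",  # same default as vllm.rs
)
```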
diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs
index ea20e7d2..dbff8af1 100644
--- a/src/llm/voyage.rs
+++ b/src/llm/voyage.rs
@@ -33,14 +33,18 @@ pub struct Client {
 }
 
 impl Client {
-    pub fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Voyage AI doesn't support custom API address");
         }
-        let api_key = match std::env::var("VOYAGE_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("VOYAGE_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("VOYAGE_API_KEY")
+                .map_err(|_| anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set"))?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs
index bd870158..825d98fa 100644
--- a/src/ops/functions/embed_text.rs
+++ b/src/ops/functions/embed_text.rs
@@ -13,6 +13,7 @@ struct Spec {
     api_config: Option<LlmApiConfig>,
     output_dimension: Option<u32>,
     task_type: Option<String>,
+    api_key: Option<String>,
 }
 
 struct Args {
@@ -91,9 +92,14 @@ impl SimpleFunctionFactoryBase for Factory {
             .next_arg("text")?
             .expect_type(&ValueType::Basic(BasicValueType::Str))?
             .required()?;
-        let client =
-            new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
-                .await?;
+
+        let client = new_llm_embedding_client(
+            spec.api_type,
+            spec.address.clone(),
+            spec.api_key.clone(),
+            spec.api_config.clone(),
+        )
+        .await?;
         let output_dimension = match spec.output_dimension {
             Some(output_dimension) => output_dimension,
             None => {
@@ -144,6 +150,7 @@ mod tests {
             api_config: None,
             output_dimension: None,
             task_type: None,
+            api_key: None,
         };
 
         let factory = Arc::new(Factory);
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 4dfe9d4d..d929f9ae 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -55,6 +55,7 @@ impl Executor {
         let client = new_llm_generation_client(
             spec.llm_spec.api_type,
             spec.llm_spec.address,
+            spec.llm_spec.api_key,
             spec.llm_spec.api_config,
         )
         .await?;
@@ -204,6 +205,7 @@ mod tests {
                 api_type: crate::llm::LlmApiType::OpenAi,
                 model: "gpt-4o".to_string(),
                 address: None,
+                api_key: None,
                 api_config: None,
             },
             output_type: output_type_spec,
@@ -274,6 +276,7 @@ mod tests {
                 api_type: crate::llm::LlmApiType::OpenAi,
                 model: "gpt-4o".to_string(),
                 address: None,
+                api_key: None,
                 api_config: None,
             },
             output_type: make_output_type(BasicValueType::Str),
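Taken together, every touched client resolves the key in the same order. A condensed, hypothetical restatement of that order (`resolve_api_key` is not a function in the codebase):

```python
import os

def resolve_api_key(explicit: str | None, env_var: str, required: bool) -> str | None:
    """Resolution order this diff implements: an explicit spec-level key
    wins; otherwise the provider's environment variable is consulted;
    providers that require a key fail if neither is set."""
    key = explicit if explicit is not None else os.environ.get(env_var)
    if required and key is None:
        raise RuntimeError(f"{env_var} environment variable must be set")
    return key

# Anthropic, Gemini (AI Studio), Voyage, and OpenAI require a key;
# LiteLLM, OpenRouter, and vLLM treat it as optional; Vertex AI accepts
# but ignores it (note the _api_key parameter above).
key = resolve_api_key(None, "VLLM_API_KEY", required=False)
```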