diff --git a/aiscript-runtime/src/config/mod.rs b/aiscript-runtime/src/config/mod.rs
index cc53c5d..8057523 100644
--- a/aiscript-runtime/src/config/mod.rs
+++ b/aiscript-runtime/src/config/mod.rs
@@ -66,7 +66,7 @@ impl AsRef<str> for EnvString {
 #[derive(Debug, Deserialize, Default)]
 pub struct Config {
     #[serde(default)]
-    pub ai: Option<AiConfig>,
+    pub ai: AiConfig,
     #[serde(default)]
     pub database: DatabaseConfig,
     #[serde(default)]
diff --git a/aiscript-vm/src/ai/agent.rs b/aiscript-vm/src/ai/agent.rs
index 90ca33d..12e0b86 100644
--- a/aiscript-vm/src/ai/agent.rs
+++ b/aiscript-vm/src/ai/agent.rs
@@ -277,8 +277,6 @@ pub async fn _run_agent<'gc>(
     mut agent: Gc<'gc, Agent<'gc>>,
     args: Vec<Value<'gc>>,
 ) -> Value<'gc> {
-    use super::default_model;
-
     let message = args[0];
     let debug = args[1].as_boolean();
     let mut history = Vec::new();
@@ -289,11 +287,13 @@ pub async fn _run_agent<'gc>(
         tool_calls: None,
         tool_call_id: None,
     });
-    let mut client = super::openai_client(state.ai_config.as_ref());
+    let model_config = state.ai_config.get_model_config(None).unwrap();
+    let mut client = super::openai_client(&model_config);
+    let model = model_config.model.unwrap();
     loop {
         let mut messages = vec![agent.get_instruction_message()];
         messages.extend(history.clone());
-        let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
+        let mut req = ChatCompletionRequest::new(model.clone(), messages);
         let tools = agent.get_tools();
         if !tools.is_empty() {
             req = req
diff --git a/aiscript-vm/src/ai/mod.rs b/aiscript-vm/src/ai/mod.rs
index 61a9f80..96ac34a 100644
--- a/aiscript-vm/src/ai/mod.rs
+++ b/aiscript-vm/src/ai/mod.rs
@@ -4,95 +4,131 @@ mod prompt;
 use std::env;
 
 pub use agent::{Agent, run_agent};
-use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
+use openai_api_rs::v1::{api::OpenAIClient, common};
 pub use prompt::{PromptConfig, prompt_with_config};
 use serde::Deserialize;
 
+// OpenAI
+const OPENAI_API_ENDPOINT: &str = "https://api.openai.com/v1";
+const OPENAI_DEFAULT_MODEL: &str = common::GPT4;
+
 // Deepseek
 const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
-const DEEPSEEK_V3: &str = "deepseek-chat";
+const DEEPSEEK_DEFAULT_MODEL: &str = "deepseek-chat";
 
 // Anthropic
 const ANTHROPIC_API_ENDPOINT: &str = "https://api.anthropic.com/v1";
-const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
+const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";
 
 #[derive(Debug, Clone, Deserialize)]
-pub enum AiConfig {
-    #[serde(rename = "openai")]
-    OpenAI(ModelConfig),
-    #[serde(rename = "anthropic")]
-    Anthropic(ModelConfig),
-    #[serde(rename = "deepseek")]
-    DeepSeek(ModelConfig),
+pub struct AiConfig {
+    pub openai: Option<ModelConfig>,
+    pub anthropic: Option<ModelConfig>,
+    pub deepseek: Option<ModelConfig>,
+}
+
+impl Default for AiConfig {
+    fn default() -> Self {
+        Self {
+            openai: env::var("OPENAI_API_KEY").ok().map(|key| ModelConfig {
+                api_key: key,
+                api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
+                model: Some(OPENAI_DEFAULT_MODEL.to_string()),
+            }),
+            anthropic: env::var("CLAUDE_API_KEY").ok().map(|key| ModelConfig {
+                api_key: key,
+                api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
+                model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
+            }),
+            deepseek: env::var("DEEPSEEK_API_KEY").ok().map(|key| ModelConfig {
+                api_key: key,
+                api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
+                model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
+            }),
+        }
+    }
 }
 
 #[derive(Debug, Clone, Deserialize)]
 pub struct ModelConfig {
     pub api_key: String,
+    pub api_endpoint: Option<String>,
     pub model: Option<String>,
 }
-impl AiConfig {
-    pub(crate) fn take_model(&mut self) -> Option<String> {
-        match self {
-            Self::OpenAI(ModelConfig { model, .. }) => model.take(),
-            Self::Anthropic(ModelConfig { model, .. }) => model.take(),
-            Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
+impl Default for ModelConfig {
+    fn default() -> Self {
+        ModelConfig {
+            #[cfg(feature = "ai_test")]
+            api_key: "".into(),
+            #[cfg(not(feature = "ai_test"))]
+            api_key: env::var("OPENAI_API_KEY")
+                .expect("Expect `OPENAI_API_KEY` environment variable."),
+            api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
+            model: Some(OPENAI_DEFAULT_MODEL.to_string()),
         }
     }
-
-    pub(crate) fn set_model(&mut self, m: String) {
-        match self {
-            Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
-            Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
-            Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
-        };
-    }
 }
 
-#[allow(unused)]
-pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
-    match config {
-        None => OpenAIClient::builder()
-            .with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
-            .build()
-            .unwrap(),
-        Some(AiConfig::OpenAI(model_config)) => {
-            let api_key = if model_config.api_key.is_empty() {
-                env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
-            } else {
-                model_config.api_key.clone()
-            };
-            OpenAIClient::builder()
-                .with_api_key(api_key)
-                .build()
-                .unwrap()
+impl AiConfig {
+    pub(crate) fn get_model_config(
+        &self,
+        model_name: Option<String>,
+    ) -> Result<ModelConfig, String> {
+        if let Some(model) = model_name {
+            match model {
+                m if m.starts_with("gpt") => {
+                    if let Some(openai) = self.openai.as_ref() {
+                        let mut config = openai.clone();
+                        config.model = Some(m);
+                        Ok(config)
+                    } else {
+                        Ok(ModelConfig::default())
+                    }
+                }
+                m if m.starts_with("claude") => {
+                    if let Some(anthropic) = self.anthropic.as_ref() {
+                        let mut config = anthropic.clone();
+                        config.model = Some(m);
+                        Ok(config)
+                    } else {
+                        Ok(ModelConfig {
+                            api_key: env::var("CLAUDE_API_KEY")
+                                .expect("Expect `CLAUDE_API_KEY` environment variable."),
+                            api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
+                            model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
+                        })
+                    }
+                }
+                m if m.starts_with("deepseek") => {
+                    if let Some(deepseek) = self.deepseek.as_ref() {
+                        let mut config = deepseek.clone();
+                        config.model = Some(m);
+                        Ok(config)
+                    } else {
+                        Ok(ModelConfig {
+                            api_key: env::var("DEEPSEEK_API_KEY")
+                                .expect("Expect `DEEPSEEK_API_KEY` environment variable."),
+                            api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
+                            model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
+                        })
+                    }
+                }
+                m => Err(format!("Unsupported model '{m}'.")),
+            }
+        } else {
+            // Default is OpenAI model
+            Ok(ModelConfig::default())
         }
-        Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
-            .with_endpoint(DEEPSEEK_API_ENDPOINT)
-            .with_api_key(api_key)
-            .build()
-            .unwrap(),
-        Some(AiConfig::Anthropic(ModelConfig { api_key, .. })) => OpenAIClient::builder()
-            .with_endpoint(ANTHROPIC_API_ENDPOINT)
-            .with_api_key(api_key)
-            .build()
-            .unwrap(),
     }
 }
 
-pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
-    match config {
-        None => GPT3_5_TURBO.to_string(),
-        Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
-            model.clone().unwrap_or(GPT3_5_TURBO.to_string())
-        }
-        Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
-            model.clone().unwrap_or(DEEPSEEK_V3.to_string())
-        }
-        Some(AiConfig::Anthropic(ModelConfig { model, .. })) => {
-            model.clone().unwrap_or(CLAUDE_3_5_SONNET.to_string())
-        }
-    }
+#[allow(unused)]
+pub(crate) fn openai_client(config: &ModelConfig) -> OpenAIClient {
+    OpenAIClient::builder()
+        .with_api_key(&config.api_key)
+        .with_endpoint(config.api_endpoint.as_ref().unwrap())
+        .build()
+        .unwrap()
 }
diff --git a/aiscript-vm/src/ai/prompt.rs b/aiscript-vm/src/ai/prompt.rs
index 68a5eb4..02b96a6 100644
--- a/aiscript-vm/src/ai/prompt.rs
+++ b/aiscript-vm/src/ai/prompt.rs
@@ -1,46 +1,16 @@
-use openai_api_rs::v1::common::GPT3_5_TURBO;
 use tokio::runtime::Handle;
 
-use super::{AiConfig, ModelConfig, default_model};
+use super::ModelConfig;
 
+#[derive(Default)]
 pub struct PromptConfig {
     pub input: String,
-    pub ai_config: Option<AiConfig>,
+    pub model_config: ModelConfig,
     pub max_tokens: Option<i64>,
     pub temperature: Option<f64>,
     pub system_prompt: Option<String>,
 }
 
-impl Default for PromptConfig {
-    fn default() -> Self {
-        Self {
-            input: String::new(),
-            ai_config: Some(AiConfig::OpenAI(ModelConfig {
-                api_key: Default::default(),
-                model: Some(GPT3_5_TURBO.to_string()),
-            })),
-            max_tokens: Default::default(),
-            temperature: Default::default(),
-            system_prompt: Default::default(),
-        }
-    }
-}
-
-impl PromptConfig {
-    fn take_model(&mut self) -> String {
-        self.ai_config
-            .as_mut()
-            .and_then(|config| config.take_model())
-            .unwrap_or_else(|| default_model(self.ai_config.as_ref()))
-    }
-
-    pub(crate) fn set_model(&mut self, model: String) {
-        if let Some(config) = self.ai_config.as_mut() {
-            config.set_model(model);
-        }
-    }
-}
-
 #[cfg(feature = "ai_test")]
 async fn _prompt_with_config(config: PromptConfig) -> String {
     return format!("AI: {}", config.input);
@@ -49,8 +19,8 @@ async fn _prompt_with_config(mut config: PromptConfig) -> String {
 #[cfg(not(feature = "ai_test"))]
 async fn _prompt_with_config(mut config: PromptConfig) -> String {
     use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
-    let mut client = super::openai_client(config.ai_config.as_ref());
-    let model = config.take_model();
+    let model = config.model_config.model.take().unwrap();
+    let mut client = super::openai_client(&config.model_config);
 
     // Create system message if provided
     let mut messages = Vec::new();
diff --git a/aiscript-vm/src/vm/mod.rs b/aiscript-vm/src/vm/mod.rs
index fe22a87..8c0675f 100644
--- a/aiscript-vm/src/vm/mod.rs
+++ b/aiscript-vm/src/vm/mod.rs
@@ -36,7 +36,7 @@ impl Display for VmError {
 
 impl Default for Vm {
     fn default() -> Self {
-        Self::new(None, None, None, None)
+        Self::new(None, None, None, AiConfig::default())
     }
 }
 
@@ -49,7 +49,7 @@ impl Vm {
         pg_connection: Option,
         sqlite_connection: Option,
         redis_connection: Option,
-        ai_config: Option<AiConfig>,
+        ai_config: AiConfig,
     ) -> Self {
         let mut vm = Vm {
             arena: Arena::<Rootable![State<'_>]>::new(|mc| {
diff --git a/aiscript-vm/src/vm/state.rs b/aiscript-vm/src/vm/state.rs
index 6f9205a..a31b5bb 100644
--- a/aiscript-vm/src/vm/state.rs
+++ b/aiscript-vm/src/vm/state.rs
@@ -110,7 +110,7 @@ pub struct State<'gc> {
     pub pg_connection: Option,
     pub sqlite_connection: Option,
     pub redis_connection: Option,
-    pub ai_config: Option<AiConfig>,
+    pub ai_config: AiConfig,
 }
 
 unsafe impl Collect for State<'_> {
@@ -153,7 +153,7 @@ impl<'gc> State<'gc> {
             pg_connection: None,
             sqlite_connection: None,
             redis_connection: None,
-            ai_config: None,
+            ai_config: AiConfig::default(),
         }
     }
 
@@ -1013,22 +1013,29 @@ impl<'gc> State<'gc> {
         let result = match value {
             // Simple string case
             Value::String(s) => {
-                let mut config = PromptConfig {
+                let config = PromptConfig {
                     input: s.to_str().unwrap().to_string(),
+                    model_config: self
+                        .ai_config
+                        .get_model_config(None)
+                        .map_err(VmError::RuntimeError)?,
                     ..Default::default()
                 };
-                if let Some(ai_cfg) = &self.ai_config {
-                    config.ai_config = Some(ai_cfg.clone());
-                }
                 ai::prompt_with_config(config)
             }
             // Object config case
             Value::Object(obj) => {
                 let mut config = PromptConfig::default();
-                if let Some(ai_cfg) = &self.ai_config {
-                    config.ai_config = Some(ai_cfg.clone());
-                }
                 let obj_ref = obj.borrow();
+                // Extract model (optional)
+                if let Some(Value::String(model)) =
+                    obj_ref.fields.get(&self.intern(b"model"))
+                {
+                    config.model_config = self
+                        .ai_config
+                        .get_model_config(Some(model.to_str().unwrap().to_string()))
+                        .map_err(VmError::RuntimeError)?
+                }
 
                 // Extract input (required)
                 if let Some(Value::String(input)) =
@@ -1041,13 +1048,6 @@ impl<'gc> State<'gc> {
                 ));
             }
 
-                // Extract model (optional)
-                if let Some(Value::String(model)) =
-                    obj_ref.fields.get(&self.intern(b"model"))
-                {
-                    config.set_model(model.to_str().unwrap().to_string());
-                }
-
                 // Extract max_tokens (optional)
                 if let Some(Value::Number(tokens)) =
                     obj_ref.fields.get(&self.intern(b"max_tokens"))
diff --git a/examples/claude.ai b/examples/claude.ai
new file mode 100644
index 0000000..494962b
--- /dev/null
+++ b/examples/claude.ai
@@ -0,0 +1,5 @@
+let a = prompt {
+    input: "What is rust?",
+    model: "claude-3-7-sonnet-latest"
+};
+print(a);
\ No newline at end of file
diff --git a/examples/project.toml b/examples/project.toml
index 5fd5c75..8b90f8d 100644
--- a/examples/project.toml
+++ b/examples/project.toml
@@ -19,3 +19,7 @@ client_id = "123"
 client_secret = "abc"
 redirect_url = "http://localhost:8080/callback"
 scopes = ["email"]
+
+[ai.anthropic]
+api_key = "$CLAUDE_API_KEY"
+model = "claude-3-5-sonnet-latest"
diff --git a/tests/integration/ai/prompt.ai b/tests/integration/ai/prompt.ai
index 0418386..9f48fc5 100644
--- a/tests/integration/ai/prompt.ai
+++ b/tests/integration/ai/prompt.ai
@@ -1,3 +1,3 @@
 let p = "What is AIScript?";
 let a = prompt p;
-print(a); // expect: AI: What is AIScript?
\ No newline at end of file
+print(a); // expect: AI: What is AIScript?
diff --git a/tests/integration/ai/unsupported_model.ai b/tests/integration/ai/unsupported_model.ai
new file mode 100644
index 0000000..ab0b5c5
--- /dev/null
+++ b/tests/integration/ai/unsupported_model.ai
@@ -0,0 +1,6 @@
+// ignore
+let a = prompt {
+    input: "hi",
+    model: "invalid-model",
+};
+// expect runtime error: Unsupported model 'invalid-model'.
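
A minimal standalone sketch of the prefix-based routing rule that the new `AiConfig::get_model_config` applies, for trying the dispatch logic outside the VM. The `provider_for` helper below is hypothetical, written only to mirror the match arms in this patch; the error string is the one asserted by tests/integration/ai/unsupported_model.ai.

// Hypothetical helper mirroring the prefix routing in `get_model_config`:
// "gpt*" -> OpenAI, "claude*" -> Anthropic, "deepseek*" -> DeepSeek,
// anything else surfaces as a runtime error.
fn provider_for(model: &str) -> Result<&'static str, String> {
    match model {
        m if m.starts_with("gpt") => Ok("openai"),
        m if m.starts_with("claude") => Ok("anthropic"),
        m if m.starts_with("deepseek") => Ok("deepseek"),
        m => Err(format!("Unsupported model '{m}'.")),
    }
}

fn main() {
    assert_eq!(provider_for("claude-3-7-sonnet-latest"), Ok("anthropic"));
    assert_eq!(
        provider_for("invalid-model").unwrap_err(),
        "Unsupported model 'invalid-model'."
    );
}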