
Make ai config more intuitive #46


Merged: 4 commits, Mar 31, 2025
2 changes: 1 addition & 1 deletion aiscript-runtime/src/config/mod.rs
@@ -66,7 +66,7 @@ impl AsRef<str> for EnvString {
#[derive(Debug, Deserialize, Default)]
pub struct Config {
#[serde(default)]
pub ai: Option<AiConfig>,
pub ai: AiConfig,
#[serde(default)]
pub database: DatabaseConfig,
#[serde(default)]
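With `#[serde(default)]` on a non-`Option` field, a `project.toml` that omits the `[ai]` table now deserializes to `AiConfig::default()` (which probes provider environment variables, per the `Default` impl later in this diff) instead of `None`. A minimal sketch of that serde behavior, assuming the `serde` and `toml` crates and a stand-in struct rather than the real `AiConfig`:

```rust
use serde::Deserialize;

// Stand-in for the real AiConfig; only the shape matters here.
#[derive(Debug, Deserialize, Default)]
struct AiConfig {
    openai: Option<String>,
}

#[derive(Debug, Deserialize, Default)]
struct Config {
    // A missing [ai] table no longer yields None; serde falls
    // back to AiConfig::default() instead.
    #[serde(default)]
    ai: AiConfig,
}

fn main() {
    let config: Config = toml::from_str("").unwrap();
    assert!(config.ai.openai.is_none());
}
```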
8 changes: 4 additions & 4 deletions aiscript-vm/src/ai/agent.rs
@@ -277,8 +277,6 @@ pub async fn _run_agent<'gc>(
mut agent: Gc<'gc, Agent<'gc>>,
args: Vec<Value<'gc>>,
) -> Value<'gc> {
use super::default_model;

let message = args[0];
let debug = args[1].as_boolean();
let mut history = Vec::new();
@@ -289,11 +287,13 @@
tool_calls: None,
tool_call_id: None,
});
let mut client = super::openai_client(state.ai_config.as_ref());
let model_config = state.ai_config.get_model_config(None).unwrap();
let mut client = super::openai_client(&model_config);
let model = model_config.model.unwrap();
loop {
let mut messages = vec![agent.get_instruction_message()];
messages.extend(history.clone());
let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
let mut req = ChatCompletionRequest::new(model.clone(), messages);
let tools = agent.get_tools();
if !tools.is_empty() {
req = req
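The loop previously recomputed the model via `default_model(state.ai_config.as_ref())` on every iteration; it now resolves a `ModelConfig` once before the loop and reuses the model name. A sketch of that flow with stub types (not the real `aiscript-vm` items); the `unwrap`s lean on the invariant that every `ModelConfig` coming out of `get_model_config` has `model: Some(..)`:

```rust
#[derive(Clone)]
struct ModelConfig {
    model: Option<String>,
}

struct AiConfig;

impl AiConfig {
    // Stub mirroring get_model_config(None): the default route always
    // yields a config whose `model` field is populated.
    fn get_model_config(&self, _name: Option<String>) -> Result<ModelConfig, String> {
        Ok(ModelConfig { model: Some("gpt-4".into()) })
    }
}

fn main() {
    let ai_config = AiConfig;
    // Resolved once, outside the completion loop.
    let model_config = ai_config.get_model_config(None).unwrap();
    let model = model_config.model.clone().unwrap();
    for turn in 0..2 {
        // Each request reuses the resolved name instead of recomputing it.
        println!("turn {turn}: request with model {model}");
    }
}
```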
164 changes: 100 additions & 64 deletions aiscript-vm/src/ai/mod.rs
@@ -4,95 +4,131 @@ mod prompt;
use std::env;

pub use agent::{Agent, run_agent};
use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
use openai_api_rs::v1::{api::OpenAIClient, common};
pub use prompt::{PromptConfig, prompt_with_config};

use serde::Deserialize;

// OpenAI
const OPENAI_API_ENDPOINT: &str = "https://api.openai.com/v1";
const OPENAI_DEFAULT_MODEL: &str = common::GPT4;

// Deepseek
const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
const DEEPSEEK_V3: &str = "deepseek-chat";
const DEEPSEEK_DEFAULT_MODEL: &str = "deepseek-chat";

// Anthropic
const ANTHROPIC_API_ENDPOINT: &str = "https://api.anthropic.com/v1";
const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";

#[derive(Debug, Clone, Deserialize)]
pub enum AiConfig {
#[serde(rename = "openai")]
OpenAI(ModelConfig),
#[serde(rename = "anthropic")]
Anthropic(ModelConfig),
#[serde(rename = "deepseek")]
DeepSeek(ModelConfig),
pub struct AiConfig {
pub openai: Option<ModelConfig>,
pub anthropic: Option<ModelConfig>,
pub deepseek: Option<ModelConfig>,
}

impl Default for AiConfig {
fn default() -> Self {
Self {
openai: env::var("OPENAI_API_KEY").ok().map(|key| ModelConfig {
api_key: key,
api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
model: Some(OPENAI_DEFAULT_MODEL.to_string()),
}),
anthropic: env::var("CLAUDE_API_KEY").ok().map(|key| ModelConfig {
api_key: key,
api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
}),
deepseek: env::var("DEEPKSEEK_API_KEY").ok().map(|key| ModelConfig {
api_key: key,
api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
}),
}
}
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
pub api_key: String,
pub api_endpoint: Option<String>,
pub model: Option<String>,
}

impl AiConfig {
pub(crate) fn take_model(&mut self) -> Option<String> {
match self {
Self::OpenAI(ModelConfig { model, .. }) => model.take(),
Self::Anthropic(ModelConfig { model, .. }) => model.take(),
Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
impl Default for ModelConfig {
fn default() -> Self {
ModelConfig {
#[cfg(feature = "ai_test")]
api_key: "".into(),
#[cfg(not(feature = "ai_test"))]
api_key: env::var("OPENAI_API_KEY")
.expect("Expect `OPENAI_API_KEY` environment variable."),
api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
model: Some(OPENAI_DEFAULT_MODEL.to_string()),
}
}

pub(crate) fn set_model(&mut self, m: String) {
match self {
Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
};
}
}

#[allow(unused)]
pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
match config {
None => OpenAIClient::builder()
.with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
.build()
.unwrap(),
Some(AiConfig::OpenAI(model_config)) => {
let api_key = if model_config.api_key.is_empty() {
env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
} else {
model_config.api_key.clone()
};
OpenAIClient::builder()
.with_api_key(api_key)
.build()
.unwrap()
impl AiConfig {
pub(crate) fn get_model_config(
&self,
model_name: Option<String>,
) -> Result<ModelConfig, String> {
if let Some(model) = model_name {
match model {
m if m.starts_with("gpt") => {
if let Some(openai) = self.openai.as_ref() {
let mut config = openai.clone();
config.model = Some(m);
Ok(config)
} else {
Ok(ModelConfig::default())
}
}
m if m.starts_with("claude") => {
if let Some(anthropic) = self.anthropic.as_ref() {
let mut config = anthropic.clone();
config.model = Some(m);
Ok(config)
} else {
Ok(ModelConfig {
api_key: env::var("CLAUDE_API_KEY")
.expect("Expect `CLAUDE_API_KEY` environment variable."),
api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
})
}
}
m if m.starts_with("deepseek") => {
if let Some(deepseek) = self.deepseek.as_ref() {
let mut config = deepseek.clone();
config.model = Some(m);
Ok(config)
} else {
Ok(ModelConfig {
api_key: env::var("DEEPSEEK_API_KEY")
.expect("Expect `DEEPSEEK_API_KEY` environment variable."),
api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
})
}
}
m => Err(format!("Unsupported model '{m}'.")),
}
} else {
// Default is OpenAI model
Ok(ModelConfig::default())
}
Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
.with_endpoint(DEEPSEEK_API_ENDPOINT)
.with_api_key(api_key)
.build()
.unwrap(),
Some(AiConfig::Anthropic(ModelConfig { api_key, .. })) => OpenAIClient::builder()
.with_endpoint(ANTHROPIC_API_ENDPOINT)
.with_api_key(api_key)
.build()
.unwrap(),
}
}

pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
match config {
None => GPT3_5_TURBO.to_string(),
Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
model.clone().unwrap_or(GPT3_5_TURBO.to_string())
}
Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
model.clone().unwrap_or(DEEPSEEK_V3.to_string())
}
Some(AiConfig::Anthropic(ModelConfig { model, .. })) => {
model.clone().unwrap_or(CLAUDE_3_5_SONNET.to_string())
}
}
#[allow(unused)]
pub(crate) fn openai_client(config: &ModelConfig) -> OpenAIClient {
OpenAIClient::builder()
.with_api_key(&config.api_key)
.with_endpoint(config.api_endpoint.as_ref().unwrap())
.build()
.unwrap()
}
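`get_model_config` is now the single resolution point: it routes on the model-name prefix (`gpt` / `claude` / `deepseek`) to the matching provider section, falls back to environment-variable defaults when that section is absent, and rejects unknown prefixes. A self-contained sketch of the routing rule (simplified: fields trimmed, and the env fallback replaced by an error, unlike the real code):

```rust
#[derive(Clone, Debug, Default)]
struct ModelConfig {
    model: Option<String>,
}

#[derive(Default)]
struct AiConfig {
    openai: Option<ModelConfig>,
    anthropic: Option<ModelConfig>,
    deepseek: Option<ModelConfig>,
}

impl AiConfig {
    // Simplified mirror of get_model_config: choose the provider section
    // by prefix, then let the requested name override the configured one.
    fn get_model_config(&self, model_name: Option<String>) -> Result<ModelConfig, String> {
        let Some(m) = model_name else {
            // Default route is OpenAI (the real code builds a
            // ModelConfig::default() that reads OPENAI_API_KEY).
            return self.openai.clone().ok_or("no [ai.openai] section".to_string());
        };
        let section = match () {
            _ if m.starts_with("gpt") => &self.openai,
            _ if m.starts_with("claude") => &self.anthropic,
            _ if m.starts_with("deepseek") => &self.deepseek,
            _ => return Err(format!("Unsupported model '{m}'.")),
        };
        let mut config = section.clone().ok_or(format!("no section configured for '{m}'"))?;
        config.model = Some(m); // requested name wins over the configured default
        Ok(config)
    }
}

fn main() {
    let cfg = AiConfig {
        anthropic: Some(ModelConfig { model: Some("claude-3-5-sonnet-latest".into()) }),
        ..Default::default()
    };
    let resolved = cfg.get_model_config(Some("claude-3-7-sonnet-latest".into())).unwrap();
    assert_eq!(resolved.model.as_deref(), Some("claude-3-7-sonnet-latest"));
}
```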
40 changes: 5 additions & 35 deletions aiscript-vm/src/ai/prompt.rs
@@ -1,46 +1,16 @@
use openai_api_rs::v1::common::GPT3_5_TURBO;
use tokio::runtime::Handle;

use super::{AiConfig, ModelConfig, default_model};
use super::ModelConfig;

#[derive(Default)]
pub struct PromptConfig {
pub input: String,
pub ai_config: Option<AiConfig>,
pub model_config: ModelConfig,
pub max_tokens: Option<i64>,
pub temperature: Option<f64>,
pub system_prompt: Option<String>,
}

impl Default for PromptConfig {
fn default() -> Self {
Self {
input: String::new(),
ai_config: Some(AiConfig::OpenAI(ModelConfig {
api_key: Default::default(),
model: Some(GPT3_5_TURBO.to_string()),
})),
max_tokens: Default::default(),
temperature: Default::default(),
system_prompt: Default::default(),
}
}
}

impl PromptConfig {
fn take_model(&mut self) -> String {
self.ai_config
.as_mut()
.and_then(|config| config.take_model())
.unwrap_or_else(|| default_model(self.ai_config.as_ref()))
}

pub(crate) fn set_model(&mut self, model: String) {
if let Some(config) = self.ai_config.as_mut() {
config.set_model(model);
}
}
}

#[cfg(feature = "ai_test")]
async fn _prompt_with_config(config: PromptConfig) -> String {
return format!("AI: {}", config.input);
@@ -49,8 +19,8 @@ async fn _prompt_with_config(config: PromptConfig) -> String {
#[cfg(not(feature = "ai_test"))]
async fn _prompt_with_config(mut config: PromptConfig) -> String {
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
let mut client = super::openai_client(config.ai_config.as_ref());
let model = config.take_model();
let model = config.model_config.model.take().unwrap();
let mut client = super::openai_client(&config.model_config);

// Create system message if provided
let mut messages = Vec::new();
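`PromptConfig` holds a concrete `ModelConfig` now, so the hand-written `Default` impl and the `take_model` / `set_model` plumbing all disappear and `#[derive(Default)]` suffices. Note the new body's one subtlety: `model.take()` moves the name out of `model_config`, and its `unwrap` relies on `get_model_config` always populating `model`. A small sketch of that take-then-use step, with stub types:

```rust
#[derive(Default)]
struct ModelConfig {
    model: Option<String>,
}

#[derive(Default)]
struct PromptConfig {
    input: String,
    model_config: ModelConfig,
}

fn main() {
    let mut config = PromptConfig {
        input: "What is rust?".into(),
        model_config: ModelConfig { model: Some("gpt-4".into()) },
    };
    // Mirrors _prompt_with_config: take() moves the name out,
    // leaving model_config.model as None afterwards.
    let model = config.model_config.model.take().unwrap();
    assert_eq!(model, "gpt-4");
    assert!(config.model_config.model.is_none());
    println!("prompting {model} with {:?}", config.input);
}
```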
4 changes: 2 additions & 2 deletions aiscript-vm/src/vm/mod.rs
@@ -36,7 +36,7 @@ impl Display for VmError

impl Default for Vm {
fn default() -> Self {
Self::new(None, None, None, None)
Self::new(None, None, None, AiConfig::default())
}
}

Expand All @@ -49,7 +49,7 @@ impl Vm {
pg_connection: Option<PgPool>,
sqlite_connection: Option<SqlitePool>,
redis_connection: Option<redis::aio::MultiplexedConnection>,
ai_config: Option<AiConfig>,
ai_config: AiConfig,
) -> Self {
let mut vm = Vm {
arena: Arena::<Rootable![State<'_>]>::new(|mc| {
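`Vm::new` now takes the config unconditionally, and `Vm::default()` wires in `AiConfig::default()`, which picks up whichever provider keys the environment supplies. A minimal sketch of the changed constructor shape (connection parameters elided, stub types):

```rust
#[derive(Default)]
struct AiConfig;

struct Vm {
    ai_config: AiConfig,
}

impl Vm {
    // Before this PR the parameter was Option<AiConfig>.
    fn new(ai_config: AiConfig) -> Self {
        Vm { ai_config }
    }
}

impl Default for Vm {
    fn default() -> Self {
        Self::new(AiConfig::default())
    }
}

fn main() {
    let vm = Vm::default();
    let _ = &vm.ai_config; // config is always present; no Option to unwrap
}
```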
32 changes: 16 additions & 16 deletions aiscript-vm/src/vm/state.rs
@@ -110,7 +110,7 @@ pub struct State<'gc> {
pub pg_connection: Option<PgPool>,
pub sqlite_connection: Option<SqlitePool>,
pub redis_connection: Option<redis::aio::MultiplexedConnection>,
pub ai_config: Option<AiConfig>,
pub ai_config: AiConfig,
}

unsafe impl Collect for State<'_> {
@@ -153,7 +153,7 @@ impl<'gc> State<'gc> {
pg_connection: None,
sqlite_connection: None,
redis_connection: None,
ai_config: None,
ai_config: AiConfig::default(),
}
}

@@ -1013,22 +1013,29 @@ impl<'gc> State<'gc> {
let result = match value {
// Simple string case
Value::String(s) => {
let mut config = PromptConfig {
let config = PromptConfig {
input: s.to_str().unwrap().to_string(),
model_config: self
.ai_config
.get_model_config(None)
.map_err(VmError::RuntimeError)?,
..Default::default()
};
if let Some(ai_cfg) = &self.ai_config {
config.ai_config = Some(ai_cfg.clone());
}
ai::prompt_with_config(config)
}
// Object config case
Value::Object(obj) => {
let mut config = PromptConfig::default();
if let Some(ai_cfg) = &self.ai_config {
config.ai_config = Some(ai_cfg.clone());
}
let obj_ref = obj.borrow();
// Extract model (optional)
if let Some(Value::String(model)) =
obj_ref.fields.get(&self.intern(b"model"))
{
config.model_config = self
.ai_config
.get_model_config(Some(model.to_str().unwrap().to_string()))
.map_err(VmError::RuntimeError)?
}

// Extract input (required)
if let Some(Value::String(input)) =
Expand All @@ -1041,13 +1048,6 @@ impl<'gc> State<'gc> {
));
}

// Extract model (optional)
if let Some(Value::String(model)) =
obj_ref.fields.get(&self.intern(b"model"))
{
config.set_model(model.to_str().unwrap().to_string());
}

// Extract max_tokens (optional)
if let Some(Value::Number(tokens)) =
obj_ref.fields.get(&self.intern(b"max_tokens"))
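The VM resolves the `ModelConfig` eagerly in both `prompt` shapes: the plain-string form takes the default route (`get_model_config(None)`), while the object form reads the `model` field first so the provider is selected before the other options, and a bad model name now surfaces as `VmError::RuntimeError` rather than a panic. A sketch of the two shapes, mirrored with stub types and `String` errors:

```rust
#[derive(Default, Debug)]
struct ModelConfig {
    model: Option<String>,
}

#[derive(Default, Debug)]
struct PromptConfig {
    input: String,
    model_config: ModelConfig,
    max_tokens: Option<i64>,
}

// Stub for AiConfig::get_model_config; None means "default provider".
fn get_model_config(name: Option<&str>) -> Result<ModelConfig, String> {
    Ok(ModelConfig { model: Some(name.unwrap_or("gpt-4").to_string()) })
}

fn main() -> Result<(), String> {
    // `prompt "What is rust?"` (string case): default route.
    let simple = PromptConfig {
        input: "What is rust?".into(),
        model_config: get_model_config(None)?,
        ..Default::default()
    };
    // `prompt { input: ..., model: "claude-..." }` (object case): the
    // model field chooses the provider before the other options load.
    let object = PromptConfig {
        input: "What is rust?".into(),
        model_config: get_model_config(Some("claude-3-7-sonnet-latest"))?,
        max_tokens: Some(256),
    };
    println!("{simple:?}\n{object:?}");
    Ok(())
}
```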
5 changes: 5 additions & 0 deletions examples/claude.ai
@@ -0,0 +1,5 @@
let a = prompt {
input: "What is rust?",
model: "claude-3-7-sonnet-latest"
};
print(a);
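This example exercises the routing above: `"claude-3-7-sonnet-latest"` starts with `claude`, so it resolves against the `[ai.anthropic]` section (or, absent one, the `CLAUDE_API_KEY` environment variable), and the requested name overrides the section's configured `claude-3-5-sonnet-latest`.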
4 changes: 4 additions & 0 deletions examples/project.toml
@@ -19,3 +19,7 @@ client_id = "123"
client_secret = "abc"
redirect_url = "http://localhost:8080/callback"
scopes = ["email"]

[ai.anthropic]
api_key = "$CLAUDE_API_KEY"
model = "claude-3-5-sonnet-latest"
2 changes: 1 addition & 1 deletion tests/integration/ai/prompt.ai
@@ -1,3 +1,3 @@
let p = "What is AIScript?";
let a = prompt p;
print(a); // expect: AI: What is AIScript?
\ No newline at end of file
print(a); // expect: AI: What is AIScript?