Skip to content

Commit 01bc904

Browse files
authored
Make ai config more intuitive (#46)
* Make ai config more intuitive * Support override default model name * Skip unsupported model test * Fix tests under ai_test feature
1 parent cb5357f commit 01bc904

File tree

10 files changed

+144
-123
lines changed

10 files changed

+144
-123
lines changed

aiscript-runtime/src/config/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl AsRef<str> for EnvString {
6666
#[derive(Debug, Deserialize, Default)]
6767
pub struct Config {
6868
#[serde(default)]
69-
pub ai: Option<AiConfig>,
69+
pub ai: AiConfig,
7070
#[serde(default)]
7171
pub database: DatabaseConfig,
7272
#[serde(default)]

aiscript-vm/src/ai/agent.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,6 @@ pub async fn _run_agent<'gc>(
277277
mut agent: Gc<'gc, Agent<'gc>>,
278278
args: Vec<Value<'gc>>,
279279
) -> Value<'gc> {
280-
use super::default_model;
281-
282280
let message = args[0];
283281
let debug = args[1].as_boolean();
284282
let mut history = Vec::new();
@@ -289,11 +287,13 @@ pub async fn _run_agent<'gc>(
289287
tool_calls: None,
290288
tool_call_id: None,
291289
});
292-
let mut client = super::openai_client(state.ai_config.as_ref());
290+
let model_config = state.ai_config.get_model_config(None).unwrap();
291+
let mut client = super::openai_client(&model_config);
292+
let model = model_config.model.unwrap();
293293
loop {
294294
let mut messages = vec![agent.get_instruction_message()];
295295
messages.extend(history.clone());
296-
let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
296+
let mut req = ChatCompletionRequest::new(model.clone(), messages);
297297
let tools = agent.get_tools();
298298
if !tools.is_empty() {
299299
req = req

aiscript-vm/src/ai/mod.rs

Lines changed: 100 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,95 +4,131 @@ mod prompt;
44
use std::env;
55

66
pub use agent::{Agent, run_agent};
7-
use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
7+
use openai_api_rs::v1::{api::OpenAIClient, common};
88
pub use prompt::{PromptConfig, prompt_with_config};
99

1010
use serde::Deserialize;
1111

12+
// OpenAI
13+
const OPENAI_API_ENDPOINT: &str = "https://api.openai.com/v1";
14+
const OPENAI_DEFAULT_MODEL: &str = common::GPT4;
15+
1216
// Deepseek
1317
const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
14-
const DEEPSEEK_V3: &str = "deepseek-chat";
18+
const DEEPSEEK_DEFAULT_MODEL: &str = "deepseek-chat";
1519

1620
// Anthropic
1721
const ANTHROPIC_API_ENDPOINT: &str = "https://api.anthropic.com/v1";
18-
const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
22+
const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";
1923

2024
#[derive(Debug, Clone, Deserialize)]
21-
pub enum AiConfig {
22-
#[serde(rename = "openai")]
23-
OpenAI(ModelConfig),
24-
#[serde(rename = "anthropic")]
25-
Anthropic(ModelConfig),
26-
#[serde(rename = "deepseek")]
27-
DeepSeek(ModelConfig),
25+
pub struct AiConfig {
26+
pub openai: Option<ModelConfig>,
27+
pub anthropic: Option<ModelConfig>,
28+
pub deepseek: Option<ModelConfig>,
29+
}
30+
31+
impl Default for AiConfig {
32+
fn default() -> Self {
33+
Self {
34+
openai: env::var("OPENAI_API_KEY").ok().map(|key| ModelConfig {
35+
api_key: key,
36+
api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
37+
model: Some(OPENAI_DEFAULT_MODEL.to_string()),
38+
}),
39+
anthropic: env::var("CLAUDE_API_KEY").ok().map(|key| ModelConfig {
40+
api_key: key,
41+
api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
42+
model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
43+
}),
44+
deepseek: env::var("DEEPSEEK_API_KEY").ok().map(|key| ModelConfig {
45+
api_key: key,
46+
api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
47+
model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
48+
}),
49+
}
50+
}
2851
}
2952

3053
#[derive(Debug, Clone, Deserialize)]
3154
pub struct ModelConfig {
3255
pub api_key: String,
56+
pub api_endpoint: Option<String>,
3357
pub model: Option<String>,
3458
}
3559

36-
impl AiConfig {
37-
pub(crate) fn take_model(&mut self) -> Option<String> {
38-
match self {
39-
Self::OpenAI(ModelConfig { model, .. }) => model.take(),
40-
Self::Anthropic(ModelConfig { model, .. }) => model.take(),
41-
Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
60+
impl Default for ModelConfig {
61+
fn default() -> Self {
62+
ModelConfig {
63+
#[cfg(feature = "ai_test")]
64+
api_key: "".into(),
65+
#[cfg(not(feature = "ai_test"))]
66+
api_key: env::var("OPENAI_API_KEY")
67+
.expect("Expect `OPENAI_API_KEY` environment variable."),
68+
api_endpoint: Some(OPENAI_API_ENDPOINT.to_string()),
69+
model: Some(OPENAI_DEFAULT_MODEL.to_string()),
4270
}
4371
}
44-
45-
pub(crate) fn set_model(&mut self, m: String) {
46-
match self {
47-
Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
48-
Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
49-
Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
50-
};
51-
}
5272
}
5373

54-
#[allow(unused)]
55-
pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
56-
match config {
57-
None => OpenAIClient::builder()
58-
.with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
59-
.build()
60-
.unwrap(),
61-
Some(AiConfig::OpenAI(model_config)) => {
62-
let api_key = if model_config.api_key.is_empty() {
63-
env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
64-
} else {
65-
model_config.api_key.clone()
66-
};
67-
OpenAIClient::builder()
68-
.with_api_key(api_key)
69-
.build()
70-
.unwrap()
74+
impl AiConfig {
75+
pub(crate) fn get_model_config(
76+
&self,
77+
model_name: Option<String>,
78+
) -> Result<ModelConfig, String> {
79+
if let Some(model) = model_name {
80+
match model {
81+
m if m.starts_with("gpt") => {
82+
if let Some(openai) = self.openai.as_ref() {
83+
let mut config = openai.clone();
84+
config.model = Some(m);
85+
Ok(config)
86+
} else {
87+
Ok(ModelConfig::default())
88+
}
89+
}
90+
m if m.starts_with("claude") => {
91+
if let Some(anthropic) = self.anthropic.as_ref() {
92+
let mut config = anthropic.clone();
93+
config.model = Some(m);
94+
Ok(config)
95+
} else {
96+
Ok(ModelConfig {
97+
api_key: env::var("CLAUDE_API_KEY")
98+
.expect("Expect `CLAUDE_API_KEY` environment variable."),
99+
api_endpoint: Some(ANTHROPIC_API_ENDPOINT.to_string()),
100+
model: Some(ANTHROPIC_DEFAULT_MODEL.to_string()),
101+
})
102+
}
103+
}
104+
m if m.starts_with("deepseek") => {
105+
if let Some(deepseek) = self.deepseek.as_ref() {
106+
let mut config = deepseek.clone();
107+
config.model = Some(m);
108+
Ok(config)
109+
} else {
110+
Ok(ModelConfig {
111+
api_key: env::var("DEEPSEEK_API_KEY")
112+
.expect("Expect `DEEPSEEK_API_KEY` environment variable."),
113+
api_endpoint: Some(DEEPSEEK_API_ENDPOINT.to_string()),
114+
model: Some(DEEPSEEK_DEFAULT_MODEL.to_string()),
115+
})
116+
}
117+
}
118+
m => Err(format!("Unsupported model '{m}'.")),
119+
}
120+
} else {
121+
// Default is OpenAI model
122+
Ok(ModelConfig::default())
71123
}
72-
Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
73-
.with_endpoint(DEEPSEEK_API_ENDPOINT)
74-
.with_api_key(api_key)
75-
.build()
76-
.unwrap(),
77-
Some(AiConfig::Anthropic(ModelConfig { api_key, .. })) => OpenAIClient::builder()
78-
.with_endpoint(ANTHROPIC_API_ENDPOINT)
79-
.with_api_key(api_key)
80-
.build()
81-
.unwrap(),
82124
}
83125
}
84126

85-
pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
86-
match config {
87-
None => GPT3_5_TURBO.to_string(),
88-
Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
89-
model.clone().unwrap_or(GPT3_5_TURBO.to_string())
90-
}
91-
Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
92-
model.clone().unwrap_or(DEEPSEEK_V3.to_string())
93-
}
94-
Some(AiConfig::Anthropic(ModelConfig { model, .. })) => {
95-
model.clone().unwrap_or(CLAUDE_3_5_SONNET.to_string())
96-
}
97-
}
127+
#[allow(unused)]
128+
pub(crate) fn openai_client(config: &ModelConfig) -> OpenAIClient {
129+
OpenAIClient::builder()
130+
.with_api_key(&config.api_key)
131+
.with_endpoint(config.api_endpoint.as_ref().unwrap())
132+
.build()
133+
.unwrap()
98134
}

aiscript-vm/src/ai/prompt.rs

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,16 @@
1-
use openai_api_rs::v1::common::GPT3_5_TURBO;
21
use tokio::runtime::Handle;
32

4-
use super::{AiConfig, ModelConfig, default_model};
3+
use super::ModelConfig;
54

5+
#[derive(Default)]
66
pub struct PromptConfig {
77
pub input: String,
8-
pub ai_config: Option<AiConfig>,
8+
pub model_config: ModelConfig,
99
pub max_tokens: Option<i64>,
1010
pub temperature: Option<f64>,
1111
pub system_prompt: Option<String>,
1212
}
1313

14-
impl Default for PromptConfig {
15-
fn default() -> Self {
16-
Self {
17-
input: String::new(),
18-
ai_config: Some(AiConfig::OpenAI(ModelConfig {
19-
api_key: Default::default(),
20-
model: Some(GPT3_5_TURBO.to_string()),
21-
})),
22-
max_tokens: Default::default(),
23-
temperature: Default::default(),
24-
system_prompt: Default::default(),
25-
}
26-
}
27-
}
28-
29-
impl PromptConfig {
30-
fn take_model(&mut self) -> String {
31-
self.ai_config
32-
.as_mut()
33-
.and_then(|config| config.take_model())
34-
.unwrap_or_else(|| default_model(self.ai_config.as_ref()))
35-
}
36-
37-
pub(crate) fn set_model(&mut self, model: String) {
38-
if let Some(config) = self.ai_config.as_mut() {
39-
config.set_model(model);
40-
}
41-
}
42-
}
43-
4414
#[cfg(feature = "ai_test")]
4515
async fn _prompt_with_config(config: PromptConfig) -> String {
4616
return format!("AI: {}", config.input);
@@ -49,8 +19,8 @@ async fn _prompt_with_config(config: PromptConfig) -> String {
4919
#[cfg(not(feature = "ai_test"))]
5020
async fn _prompt_with_config(mut config: PromptConfig) -> String {
5121
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
52-
let mut client = super::openai_client(config.ai_config.as_ref());
53-
let model = config.take_model();
22+
let model = config.model_config.model.take().unwrap();
23+
let mut client = super::openai_client(&config.model_config);
5424

5525
// Create system message if provided
5626
let mut messages = Vec::new();

aiscript-vm/src/vm/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl Display for VmError {
3636

3737
impl Default for Vm {
3838
fn default() -> Self {
39-
Self::new(None, None, None, None)
39+
Self::new(None, None, None, AiConfig::default())
4040
}
4141
}
4242

@@ -49,7 +49,7 @@ impl Vm {
4949
pg_connection: Option<PgPool>,
5050
sqlite_connection: Option<SqlitePool>,
5151
redis_connection: Option<redis::aio::MultiplexedConnection>,
52-
ai_config: Option<AiConfig>,
52+
ai_config: AiConfig,
5353
) -> Self {
5454
let mut vm = Vm {
5555
arena: Arena::<Rootable![State<'_>]>::new(|mc| {

aiscript-vm/src/vm/state.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ pub struct State<'gc> {
110110
pub pg_connection: Option<PgPool>,
111111
pub sqlite_connection: Option<SqlitePool>,
112112
pub redis_connection: Option<redis::aio::MultiplexedConnection>,
113-
pub ai_config: Option<AiConfig>,
113+
pub ai_config: AiConfig,
114114
}
115115

116116
unsafe impl Collect for State<'_> {
@@ -153,7 +153,7 @@ impl<'gc> State<'gc> {
153153
pg_connection: None,
154154
sqlite_connection: None,
155155
redis_connection: None,
156-
ai_config: None,
156+
ai_config: AiConfig::default(),
157157
}
158158
}
159159

@@ -1013,22 +1013,29 @@ impl<'gc> State<'gc> {
10131013
let result = match value {
10141014
// Simple string case
10151015
Value::String(s) => {
1016-
let mut config = PromptConfig {
1016+
let config = PromptConfig {
10171017
input: s.to_str().unwrap().to_string(),
1018+
model_config: self
1019+
.ai_config
1020+
.get_model_config(None)
1021+
.map_err(VmError::RuntimeError)?,
10181022
..Default::default()
10191023
};
1020-
if let Some(ai_cfg) = &self.ai_config {
1021-
config.ai_config = Some(ai_cfg.clone());
1022-
}
10231024
ai::prompt_with_config(config)
10241025
}
10251026
// Object config case
10261027
Value::Object(obj) => {
10271028
let mut config = PromptConfig::default();
1028-
if let Some(ai_cfg) = &self.ai_config {
1029-
config.ai_config = Some(ai_cfg.clone());
1030-
}
10311029
let obj_ref = obj.borrow();
1030+
// Extract model (optional)
1031+
if let Some(Value::String(model)) =
1032+
obj_ref.fields.get(&self.intern(b"model"))
1033+
{
1034+
config.model_config = self
1035+
.ai_config
1036+
.get_model_config(Some(model.to_str().unwrap().to_string()))
1037+
.map_err(VmError::RuntimeError)?
1038+
}
10321039

10331040
// Extract input (required)
10341041
if let Some(Value::String(input)) =
@@ -1041,13 +1048,6 @@ impl<'gc> State<'gc> {
10411048
));
10421049
}
10431050

1044-
// Extract model (optional)
1045-
if let Some(Value::String(model)) =
1046-
obj_ref.fields.get(&self.intern(b"model"))
1047-
{
1048-
config.set_model(model.to_str().unwrap().to_string());
1049-
}
1050-
10511051
// Extract max_tokens (optional)
10521052
if let Some(Value::Number(tokens)) =
10531053
obj_ref.fields.get(&self.intern(b"max_tokens"))

examples/claude.ai

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
let a = prompt {
2+
input: "What is rust?",
3+
model: "claude-3-7-sonnet-latest"
4+
};
5+
print(a);

examples/project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ client_id = "123"
1919
client_secret = "abc"
2020
redirect_url = "http://localhost:8080/callback"
2121
scopes = ["email"]
22+
23+
[ai.anthropic]
24+
api_key = "$CLAUDE_API_KEY"
25+
model = "claude-3-5-sonnet-latest"

tests/integration/ai/prompt.ai

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
let p = "What is AIScript?";
22
let a = prompt p;
3-
print(a); // expect: AI: What is AIScript?
3+
print(a); // expect: AI: What is AIScript?

0 commit comments

Comments
 (0)