Skip to content

Commit 400c373

Browse files
added the count flag and made the code scalable to add future flags if needed (#35)
1 parent 8a1fcfc commit 400c373

4 files changed

Lines changed: 53 additions & 29 deletions

File tree

tiles/src/commands/mod.rs

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,10 @@
11
// Module that handles CLI commands
22

3-
use anyhow::Result;
4-
use tilekit::{modelfile, modelfile::Modelfile};
53
use tiles::runtime::Runtime;
64
use tiles::{core::health, runtime::RunArgs};
7-
const DEFAULT_MODELFILE: &str = "
8-
FROM driaforall/mem-agent-mlx-4bit
9-
";
105

11-
pub async fn run(runtime: &Runtime, modelfile: Option<String>) {
12-
let modelfile_parse_result: Result<Modelfile, String> = if let Some(modelfile_str) = modelfile {
13-
modelfile::parse_from_file(modelfile_str.as_str())
14-
} else {
15-
modelfile::parse(DEFAULT_MODELFILE)
16-
};
17-
match modelfile_parse_result {
18-
Ok(modelfile) => {
19-
let run_args = RunArgs { modelfile };
20-
runtime.run(run_args).await;
21-
}
22-
Err(_err) => println!("Invalid Modelfile"),
23-
}
6+
pub async fn run(runtime: &Runtime, run_args: RunArgs) {
7+
runtime.run(run_args).await;
248
}
259

2610
pub fn check_health() {

tiles/src/main.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::error::Error;
22

33
use clap::{Args, Parser, Subcommand};
4-
use tiles::runtime::build_runtime;
4+
use tiles::runtime::{build_runtime, RunArgs};
55
mod commands;
66
#[derive(Debug, Parser)]
77
#[command(name = "tiles")]
@@ -14,7 +14,13 @@ struct Cli {
1414
#[derive(Subcommand, Debug)]
1515
enum Commands {
1616
/// Runs the given Modelfile (runs the default model if none passed)
17-
Run { modelfile_path: Option<String> },
17+
Run {
18+
/// Path to the Modelfile (uses default model if not provided)
19+
modelfile_path: Option<String>,
20+
21+
#[command(flatten)]
22+
flags: RunFlags,
23+
},
1824

1925
/// Checks the status of dependencies
2026
Health,
@@ -23,6 +29,17 @@ enum Commands {
2329
Server(ServerArgs),
2430
}
2531

32+
#[derive(Debug, Args)]
33+
struct RunFlags {
34+
/// Number of chat retries before giving up (default: 6)
35+
#[arg(short = 'r', long, default_value_t = 6)]
36+
retry_count: u32,
37+
38+
// Future flags go here:
39+
// #[arg(long, default_value_t = 6969)]
40+
// port: u16,
41+
}
42+
2643
#[derive(Debug, Args)]
2744
#[command(args_conflicts_with_subcommands = true)]
2845
#[command(flatten_help = true)]
@@ -44,8 +61,12 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
4461
let cli = Cli::parse();
4562
let runtime = build_runtime();
4663
match cli.command {
47-
Commands::Run { modelfile_path } => {
48-
commands::run(&runtime, modelfile_path).await;
64+
Commands::Run { modelfile_path, flags } => {
65+
let run_args = RunArgs {
66+
modelfile_path,
67+
retry_count: flags.retry_count,
68+
};
69+
commands::run(&runtime, run_args).await;
4970
}
5071
Commands::Health => {
5172
commands::check_health();

tiles/src/runtime/mlx.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,28 @@ impl MLXRuntime {
3434
}
3535

3636
pub async fn run(&self, run_args: super::RunArgs) {
37-
let model = run_args.modelfile.from.as_ref().unwrap();
37+
const DEFAULT_MODELFILE: &str = "FROM driaforall/mem-agent-mlx-4bit";
38+
39+
// Parse modelfile
40+
let modelfile_parse_result = if let Some(modelfile_str) = run_args.modelfile_path {
41+
tilekit::modelfile::parse_from_file(modelfile_str.as_str())
42+
} else {
43+
tilekit::modelfile::parse(DEFAULT_MODELFILE)
44+
};
45+
46+
let modelfile = match modelfile_parse_result {
47+
Ok(mf) => mf,
48+
Err(_err) => {
49+
println!("Invalid Modelfile");
50+
return;
51+
}
52+
};
53+
54+
let model = modelfile.from.as_ref().unwrap();
3855
if model.starts_with("driaforall/mem-agent") {
39-
let _res = run_model_with_server(self, run_args.modelfile).await;
56+
let _res = run_model_with_server(self, modelfile, run_args.retry_count).await;
4057
} else {
41-
run_model_by_sub_process(run_args.modelfile);
58+
run_model_by_sub_process(modelfile);
4259
}
4360
}
4461

@@ -155,6 +172,7 @@ fn run_model_by_sub_process(modelfile: Modelfile) {
155172
async fn run_model_with_server(
156173
mlx_runtime: &MLXRuntime,
157174
modelfile: Modelfile,
175+
retry_count: u32,
158176
) -> reqwest::Result<()> {
159177
if !cfg!(debug_assertions) {
160178
let _res = mlx_runtime.start_server_daemon().await;
@@ -185,12 +203,12 @@ async fn run_model_with_server(
185203
break;
186204
}
187205
_ => {
188-
let mut remaining_count = 6;
206+
let mut remaining_count = retry_count;
189207
let mut g_reply: String = "".to_owned();
190208
let mut python_code: String = "".to_owned();
191209
loop {
192210
if remaining_count > 0 {
193-
let chat_start = remaining_count == 6;
211+
let chat_start = remaining_count == retry_count;
194212
if let Ok(response) = chat(input, modelname, chat_start, &python_code).await
195213
{
196214
if response.reply.is_empty() {

tiles/src/runtime/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
use crate::runtime::cpu::CPURuntime;
33
use crate::runtime::mlx::MLXRuntime;
44
use anyhow::Result;
5-
use tilekit::modelfile::Modelfile;
65
pub mod cpu;
76
pub mod mlx;
87

98
pub struct RunArgs {
10-
pub modelfile: Modelfile,
9+
pub modelfile_path: Option<String>,
10+
pub retry_count: u32,
11+
// Future flags go here
1112
}
1213

1314
pub enum Runtime {

0 commit comments

Comments
 (0)