Skip to content

Commit fe14a70

Browse files
authored
Moved model downloading ownership to CLI (#38)
* feat: Moved model downloading ownership to cli from py server * refactor: passing CI * fix: using String's ends_with instead of contains * fix: changed the default relay_count to 10 * fix: changed retry_count to relay_count
1 parent 5a32220 commit fe14a70

10 files changed

Lines changed: 517 additions & 42 deletions

File tree

Cargo.lock

Lines changed: 383 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/backend/mlx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
logger = logging.getLogger("app")
1414

15-
from typing import Any, Dict, List, Optional, Union
15+
from typing import Any, Dict, Iterator, List, Optional, Union
1616

1717
_model_cache: Dict[str, MLXRunner] = {}
1818
_default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default
@@ -181,3 +181,4 @@ def format_chat_messages_for_runner(
181181
def count_tokens(text: str) -> int:
182182
"""Rough token count estimation."""
183183
return int(len(text.split()) * 1.3) # Approximation, convert to int
184+

tiles/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ anyhow = "1.0"
1313
tokio = { version = "1" , features = ["macros", "rt-multi-thread"]}
1414
owo-colors = "4"
1515
futures-util = "0.3"
16-
16+
hf-hub = {version = "0.4", features = ["tokio"]}

tiles/src/core/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
// To be deprecated and removed: the core functionality will move to the tilekit SDK.
2+
13
pub mod health;

tiles/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pub mod core;
22
pub mod runtime;
3-
3+
pub mod utils;
44
#[cfg(test)]
55
mod tests {}

tiles/src/main.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::error::Error;
22

33
use clap::{Args, Parser, Subcommand};
4-
use tiles::runtime::{build_runtime, RunArgs};
4+
use tiles::runtime::{RunArgs, build_runtime};
55
mod commands;
66
#[derive(Debug, Parser)]
77
#[command(name = "tiles")]
@@ -31,10 +31,9 @@ enum Commands {
3131

3232
#[derive(Debug, Args)]
3333
struct RunFlags {
34-
/// Number of chat retries before giving up (default: 6)
35-
#[arg(short = 'r', long, default_value_t = 6)]
36-
retry_count: u32,
37-
34+
/// Max times cli communicates with the model until it gets a proper reply for a user prompt
35+
#[arg(short = 'r', long, default_value_t = 10)]
36+
relay_count: u32,
3837
// Future flags go here:
3938
// #[arg(long, default_value_t = 6969)]
4039
// port: u16,
@@ -61,10 +60,13 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
6160
let cli = Cli::parse();
6261
let runtime = build_runtime();
6362
match cli.command {
64-
Commands::Run { modelfile_path, flags } => {
63+
Commands::Run {
64+
modelfile_path,
65+
flags,
66+
} => {
6567
let run_args = RunArgs {
6668
modelfile_path,
67-
retry_count: flags.retry_count,
69+
relay_count: flags.relay_count,
6870
};
6971
commands::run(&runtime, run_args).await;
7072
}

tiles/src/runtime/mlx.rs

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::runtime::RunArgs;
2+
use crate::utils::hf_model_downloader::*;
13
use anyhow::{Context, Result};
24
use futures_util::StreamExt;
35
use owo_colors::OwoColorize;
@@ -12,7 +14,6 @@ use std::{env, fs};
1214
use std::{io, process::Command};
1315
use tilekit::modelfile::Modelfile;
1416
use tokio::time::sleep;
15-
1617
pub struct MLXRuntime {}
1718

1819
impl MLXRuntime {}
@@ -37,7 +38,7 @@ impl MLXRuntime {
3738
const DEFAULT_MODELFILE: &str = "FROM driaforall/mem-agent-mlx-4bit";
3839

3940
// Parse modelfile
40-
let modelfile_parse_result = if let Some(modelfile_str) = run_args.modelfile_path {
41+
let modelfile_parse_result = if let Some(modelfile_str) = &run_args.modelfile_path {
4142
tilekit::modelfile::parse_from_file(modelfile_str.as_str())
4243
} else {
4344
tilekit::modelfile::parse(DEFAULT_MODELFILE)
@@ -53,7 +54,7 @@ impl MLXRuntime {
5354

5455
let model = modelfile.from.as_ref().unwrap();
5556
if model.starts_with("driaforall/mem-agent") {
56-
let _res = run_model_with_server(self, modelfile, run_args.retry_count).await;
57+
let _res = run_model_with_server(self, modelfile, &run_args).await;
5758
} else {
5859
run_model_by_sub_process(modelfile);
5960
}
@@ -172,20 +173,27 @@ fn run_model_by_sub_process(modelfile: Modelfile) {
172173
async fn run_model_with_server(
173174
mlx_runtime: &MLXRuntime,
174175
modelfile: Modelfile,
175-
retry_count: u32,
176+
run_args: &RunArgs,
176177
) -> reqwest::Result<()> {
177178
if !cfg!(debug_assertions) {
178179
let _res = mlx_runtime.start_server_daemon().await;
179180
let _ = wait_until_server_is_up().await;
180181
}
181-
let stdin = io::stdin();
182-
let mut stdout = io::stdout();
183182
// loading the model from mem-agent via daemon server
184183
let memory_path = get_memory_path()
185184
.context("Retrieving memory_path failed")
186185
.unwrap();
187186
let modelname = modelfile.from.as_ref().unwrap();
188-
load_model(modelname, &memory_path).await.unwrap();
187+
match load_model(modelname, &memory_path).await {
188+
Ok(_) => start_repl(mlx_runtime, modelname, run_args).await,
189+
Err(err) => println!("{}", err),
190+
}
191+
Ok(())
192+
}
193+
194+
async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArgs) {
195+
let stdin = io::stdin();
196+
let mut stdout = io::stdout();
189197
println!("Running in interactive mode");
190198
// TODO: Handle "enter" key press or any key press when repl is processing an input
191199
loop {
@@ -203,12 +211,12 @@ async fn run_model_with_server(
203211
break;
204212
}
205213
_ => {
206-
let mut remaining_count = retry_count;
214+
let mut remaining_count = run_args.relay_count;
207215
let mut g_reply: String = "".to_owned();
208216
let mut python_code: String = "".to_owned();
209217
loop {
210218
if remaining_count > 0 {
211-
let chat_start = remaining_count == retry_count;
219+
let chat_start = remaining_count == run_args.relay_count;
212220
if let Ok(response) = chat(input, modelname, chat_start, &python_code).await
213221
{
214222
if response.reply.is_empty() {
@@ -233,7 +241,6 @@ async fn run_model_with_server(
233241
}
234242
}
235243
}
236-
Ok(())
237244
}
238245

239246
async fn ping() -> Result<(), String> {
@@ -252,6 +259,8 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> {
252259
"model": model_name,
253260
"memory_path": memory_path
254261
});
262+
263+
//TODO: Fix the unwrap here
255264
let res = client
256265
.post("http://127.0.0.1:6969/start")
257266
.json(&body)
@@ -260,33 +269,26 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> {
260269
.unwrap();
261270
match res.status() {
262271
StatusCode::OK => Ok(()),
263-
StatusCode::NOT_FOUND => download_model(model_name).await,
272+
StatusCode::NOT_FOUND => {
273+
println!("Downloading {}\n", model_name);
274+
match pull_model(model_name).await {
275+
Ok(_) => {
276+
println!("\nDownloading completed \n");
277+
Ok(())
278+
}
279+
Err(err) => Err(err),
280+
}
281+
}
264282
_ => {
265283
println!("err {:?}", res);
266-
Ok(())
284+
Err(format!(
285+
"Failed to load model {} due to {:?}",
286+
model_name, res
287+
))
267288
}
268289
}
269290
}
270291

271-
async fn download_model(model_name: &str) -> Result<(), String> {
272-
println!("Downloading the model {} ....", model_name);
273-
let client = Client::new();
274-
let body = json!({
275-
"model": model_name
276-
});
277-
let res = client
278-
.post("http://127.0.0.1:6969/download")
279-
.json(&body)
280-
.send()
281-
.await
282-
.unwrap();
283-
if res.status() == 200 {
284-
Ok(())
285-
} else {
286-
Err(String::from("Downloading model failed"))
287-
}
288-
}
289-
290292
async fn chat(
291293
input: &str,
292294
model_name: &str,

tiles/src/runtime/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub mod mlx;
77

88
pub struct RunArgs {
99
pub modelfile_path: Option<String>,
10-
pub retry_count: u32,
10+
pub relay_count: u32,
1111
// Future flags go here
1212
}
1313

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//! Manages model snapshot downloading from HuggingFace.
2+
use std::{env, path::PathBuf};
3+
4+
use hf_hub::api::{
5+
Siblings,
6+
tokio::{ApiBuilder, ApiError},
7+
};
8+
9+
/// Download the entire model (including snapshot) for the given model name
10+
pub async fn pull_model(model_name: &str) -> Result<(), String> {
11+
snapshot_download(model_name).await
12+
}
13+
14+
pub async fn snapshot_download(modelname: &str) -> Result<(), String> {
15+
let allow_patterns = [
16+
".json",
17+
".txt",
18+
".safetensors",
19+
".md",
20+
".gitattributes",
21+
"LICENSE",
22+
];
23+
let api_build_result = ApiBuilder::new()
24+
.with_progress(true)
25+
.with_cache_dir(PathBuf::from(get_model_cache()))
26+
.build();
27+
28+
match api_build_result {
29+
Ok(api) => {
30+
let repo = api.model(modelname.to_owned());
31+
match repo.info().await {
32+
Ok(repo_info) => {
33+
let filtered_siblings = repo_info
34+
.siblings
35+
.iter()
36+
.filter(|sibling| {
37+
allow_patterns
38+
.iter()
39+
.any(|pat| sibling.rfilename.ends_with(pat))
40+
})
41+
.collect::<Vec<&Siblings>>();
42+
43+
for sibling in filtered_siblings {
44+
if repo.get(&sibling.rfilename).await.is_err() {
45+
return Err(format!(
46+
"{:?} failed to download, retry again",
47+
&sibling.rfilename,
48+
));
49+
}
50+
}
51+
}
52+
Err(err) => return Err(format_hf_api_error(err)),
53+
};
54+
}
55+
Err(err) => return Err(format_hf_api_error(err)),
56+
}
57+
58+
Ok(())
59+
}
60+
61+
fn format_hf_api_error(api_error: ApiError) -> String {
62+
match api_error {
63+
ApiError::RequestError(err) => err.to_string(),
64+
ApiError::TooManyRetries(err) => err.to_string(),
65+
_err => "Something unexpected happened, check your internet connection".to_owned(),
66+
}
67+
}
68+
69+
/// Resolves the HuggingFace hub cache directory used for downloaded models.
///
/// Returns `$HF_HOME/hub` when `HF_HOME` is set to a non-empty value, and
/// `<home>/.cache/huggingface/hub` otherwise.
///
/// # Panics
/// Panics only when `HF_HOME` is unset (or empty) and the user's home
/// directory cannot be determined or is not valid UTF-8.
fn get_model_cache() -> String {
    let cache_root = match env::var("HF_HOME") {
        // An empty value would resolve to the filesystem root ("/hub"), so it
        // is treated the same as an unset variable.
        Ok(hf_home) if !hf_home.is_empty() => hf_home,
        // The home directory is looked up only when actually needed, so a
        // valid HF_HOME works even in environments without a home directory.
        _ => {
            let home_dir = env::home_dir()
                .expect("could not determine the home directory; set HF_HOME instead");
            let home_str = home_dir
                .to_str()
                .expect("home directory path is not valid UTF-8; set HF_HOME instead");
            format!("{}/.cache/huggingface", home_str)
        }
    };

    format!("{}/hub", cache_root)
}
82+
83+
#[cfg(test)]
84+
mod tests {}

tiles/src/utils/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod hf_model_downloader;

0 commit comments

Comments
 (0)