1+ use crate :: runtime:: RunArgs ;
2+ use crate :: utils:: hf_model_downloader:: * ;
13use anyhow:: { Context , Result } ;
24use futures_util:: StreamExt ;
35use owo_colors:: OwoColorize ;
@@ -12,7 +14,6 @@ use std::{env, fs};
1214use std:: { io, process:: Command } ;
1315use tilekit:: modelfile:: Modelfile ;
1416use tokio:: time:: sleep;
15-
1617pub struct MLXRuntime { }
1718
1819impl MLXRuntime { }
@@ -37,7 +38,7 @@ impl MLXRuntime {
3738 const DEFAULT_MODELFILE : & str = "FROM driaforall/mem-agent-mlx-4bit" ;
3839
3940 // Parse modelfile
40- let modelfile_parse_result = if let Some ( modelfile_str) = run_args. modelfile_path {
41+ let modelfile_parse_result = if let Some ( modelfile_str) = & run_args. modelfile_path {
4142 tilekit:: modelfile:: parse_from_file ( modelfile_str. as_str ( ) )
4243 } else {
4344 tilekit:: modelfile:: parse ( DEFAULT_MODELFILE )
@@ -53,7 +54,7 @@ impl MLXRuntime {
5354
5455 let model = modelfile. from . as_ref ( ) . unwrap ( ) ;
5556 if model. starts_with ( "driaforall/mem-agent" ) {
56- let _res = run_model_with_server ( self , modelfile, run_args. retry_count ) . await ;
57+ let _res = run_model_with_server ( self , modelfile, & run_args) . await ;
5758 } else {
5859 run_model_by_sub_process ( modelfile) ;
5960 }
@@ -172,20 +173,27 @@ fn run_model_by_sub_process(modelfile: Modelfile) {
172173async fn run_model_with_server (
173174 mlx_runtime : & MLXRuntime ,
174175 modelfile : Modelfile ,
175- retry_count : u32 ,
176+ run_args : & RunArgs ,
176177) -> reqwest:: Result < ( ) > {
177178 if !cfg ! ( debug_assertions) {
178179 let _res = mlx_runtime. start_server_daemon ( ) . await ;
179180 let _ = wait_until_server_is_up ( ) . await ;
180181 }
181- let stdin = io:: stdin ( ) ;
182- let mut stdout = io:: stdout ( ) ;
183182 // loading the model from mem-agent via daemon server
184183 let memory_path = get_memory_path ( )
185184 . context ( "Retrieving memory_path failed" )
186185 . unwrap ( ) ;
187186 let modelname = modelfile. from . as_ref ( ) . unwrap ( ) ;
188- load_model ( modelname, & memory_path) . await . unwrap ( ) ;
187+ match load_model ( modelname, & memory_path) . await {
188+ Ok ( _) => start_repl ( mlx_runtime, modelname, run_args) . await ,
189+ Err ( err) => println ! ( "{}" , err) ,
190+ }
191+ Ok ( ( ) )
192+ }
193+
194+ async fn start_repl ( mlx_runtime : & MLXRuntime , modelname : & str , run_args : & RunArgs ) {
195+ let stdin = io:: stdin ( ) ;
196+ let mut stdout = io:: stdout ( ) ;
189197 println ! ( "Running in interactive mode" ) ;
190198 // TODO: Handle "enter" key press or any key press when repl is processing an input
191199 loop {
@@ -203,12 +211,12 @@ async fn run_model_with_server(
203211 break ;
204212 }
205213 _ => {
206- let mut remaining_count = retry_count ;
214+ let mut remaining_count = run_args . relay_count ; // NOTE(review): "relay_count" looks like a typo for "retry_count" — the old code read run_args.retry_count; confirm the actual RunArgs field name
207215 let mut g_reply: String = "" . to_owned ( ) ;
208216 let mut python_code: String = "" . to_owned ( ) ;
209217 loop {
210218 if remaining_count > 0 {
211- let chat_start = remaining_count == retry_count ;
219+ let chat_start = remaining_count == run_args . relay_count ;
212220 if let Ok ( response) = chat ( input, modelname, chat_start, & python_code) . await
213221 {
214222 if response. reply . is_empty ( ) {
@@ -233,7 +241,6 @@ async fn run_model_with_server(
233241 }
234242 }
235243 }
236- Ok ( ( ) )
237244}
238245
239246async fn ping ( ) -> Result < ( ) , String > {
@@ -252,6 +259,8 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> {
252259 "model" : model_name,
253260 "memory_path" : memory_path
254261 } ) ;
262+
263+ //TODO: Fix the unwrap here
255264 let res = client
256265 . post ( "http://127.0.0.1:6969/start" )
257266 . json ( & body)
@@ -260,33 +269,26 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> {
260269 . unwrap ( ) ;
261270 match res. status ( ) {
262271 StatusCode :: OK => Ok ( ( ) ) ,
263- StatusCode :: NOT_FOUND => download_model ( model_name) . await ,
272+ StatusCode :: NOT_FOUND => {
273+ println ! ( "Downloading {}\n " , model_name) ;
274+ match pull_model ( model_name) . await {
275+ Ok ( _) => {
276+ println ! ( "\n Downloading completed \n " ) ;
277+ Ok ( ( ) )
278+ }
279+ Err ( err) => Err ( err) ,
280+ }
281+ }
264282 _ => {
265283 println ! ( "err {:?}" , res) ;
266- Ok ( ( ) )
284+ Err ( format ! (
285+ "Failed to load model {} due to {:?}" ,
286+ model_name, res
287+ ) )
267288 }
268289 }
269290}
270291
271- async fn download_model ( model_name : & str ) -> Result < ( ) , String > {
272- println ! ( "Downloading the model {} ...." , model_name) ;
273- let client = Client :: new ( ) ;
274- let body = json ! ( {
275- "model" : model_name
276- } ) ;
277- let res = client
278- . post ( "http://127.0.0.1:6969/download" )
279- . json ( & body)
280- . send ( )
281- . await
282- . unwrap ( ) ;
283- if res. status ( ) == 200 {
284- Ok ( ( ) )
285- } else {
286- Err ( String :: from ( "Downloading model failed" ) )
287- }
288- }
289-
290292async fn chat (
291293 input : & str ,
292294 model_name : & str ,
0 commit comments