@@ -34,11 +34,28 @@ impl MLXRuntime {
3434 }
3535
3636 pub async fn run ( & self , run_args : super :: RunArgs ) {
37- let model = run_args. modelfile . from . as_ref ( ) . unwrap ( ) ;
37+ const DEFAULT_MODELFILE : & str = "FROM driaforall/mem-agent-mlx-4bit" ;
38+
39+ // Parse modelfile
40+ let modelfile_parse_result = if let Some ( modelfile_str) = run_args. modelfile_path {
41+ tilekit:: modelfile:: parse_from_file ( modelfile_str. as_str ( ) )
42+ } else {
43+ tilekit:: modelfile:: parse ( DEFAULT_MODELFILE )
44+ } ;
45+
46+ let modelfile = match modelfile_parse_result {
47+ Ok ( mf) => mf,
48+ Err ( _err) => {
49+ println ! ( "Invalid Modelfile" ) ;
50+ return ;
51+ }
52+ } ;
53+
54+ let model = modelfile. from . as_ref ( ) . unwrap ( ) ;
3855 if model. starts_with ( "driaforall/mem-agent" ) {
39- let _res = run_model_with_server ( self , run_args. modelfile ) . await ;
56+ let _res = run_model_with_server ( self , modelfile , run_args. retry_count ) . await ;
4057 } else {
41- run_model_by_sub_process ( run_args . modelfile ) ;
58+ run_model_by_sub_process ( modelfile) ;
4259 }
4360 }
4461
@@ -155,6 +172,7 @@ fn run_model_by_sub_process(modelfile: Modelfile) {
155172async fn run_model_with_server (
156173 mlx_runtime : & MLXRuntime ,
157174 modelfile : Modelfile ,
175+ retry_count : u32 ,
158176) -> reqwest:: Result < ( ) > {
159177 if !cfg ! ( debug_assertions) {
160178 let _res = mlx_runtime. start_server_daemon ( ) . await ;
@@ -185,12 +203,12 @@ async fn run_model_with_server(
185203 break ;
186204 }
187205 _ => {
188- let mut remaining_count = 6 ;
206+ let mut remaining_count = retry_count ;
189207 let mut g_reply: String = "" . to_owned ( ) ;
190208 let mut python_code: String = "" . to_owned ( ) ;
191209 loop {
192210 if remaining_count > 0 {
193- let chat_start = remaining_count == 6 ;
211+ let chat_start = remaining_count == retry_count ;
194212 if let Ok ( response) = chat ( input, modelname, chat_start, & python_code) . await
195213 {
196214 if response. reply . is_empty ( ) {
0 commit comments