1+ use std:: collections:: HashMap ;
12use std:: fs;
23use std:: path:: PathBuf ;
4+ use std:: sync:: { Arc , Mutex } ;
35use std:: time:: Instant ;
46
57use futures_util:: StreamExt ;
@@ -8,7 +10,7 @@ use tauri::State;
810use crate :: state:: { AppState , DownloadProgress } ;
911
1012fn models_dir ( comfy_path : & Option < String > , subfolder : & str ) -> Result < PathBuf , String > {
11- let base = comfy_path. as_ref ( ) . ok_or ( "ComfyUI path not set" ) ?;
13+ let base = comfy_path. as_ref ( ) . ok_or ( "ComfyUI path not set. Please set it in settings or install ComfyUI first. " ) ?;
1214 let dir = PathBuf :: from ( base) . join ( "models" ) . join ( subfolder) ;
1315 fs:: create_dir_all ( & dir) . map_err ( |e| format ! ( "Create models dir: {}" , e) ) ?;
1416 Ok ( dir)
@@ -48,23 +50,46 @@ pub async fn download_model(
4850 } ) ;
4951 }
5052
53+ // Clone the Arc so the spawned task can update progress
54+ let downloads_arc = Arc :: clone ( & state. downloads ) ;
5155 let id_clone = id. clone ( ) ;
5256 let filename_clone = filename. clone ( ) ;
5357
5458 tokio:: spawn ( async move {
55- match do_download ( & url, & dest_file) . await {
56- Ok ( _) => println ! ( "[Download] Complete: {}" , filename_clone) ,
57- Err ( e) => println ! ( "[Download] Failed: {} - {}" , filename_clone, e) ,
59+ match do_download ( & url, & dest_file, & downloads_arc, & id_clone) . await {
60+ Ok ( _) => {
61+ if let Ok ( mut dl) = downloads_arc. lock ( ) {
62+ if let Some ( p) = dl. get_mut ( & id_clone) {
63+ p. status = "complete" . to_string ( ) ;
64+ }
65+ }
66+ println ! ( "[Download] Complete: {}" , filename_clone) ;
67+ }
68+ Err ( e) => {
69+ if let Ok ( mut dl) = downloads_arc. lock ( ) {
70+ if let Some ( p) = dl. get_mut ( & id_clone) {
71+ p. status = "error" . to_string ( ) ;
72+ p. error = Some ( e. clone ( ) ) ;
73+ }
74+ }
75+ println ! ( "[Download] Failed: {} - {}" , filename_clone, e) ;
76+ }
5877 }
5978 } ) ;
6079
6180 Ok ( serde_json:: json!( { "status" : "started" , "id" : id} ) )
6281}
6382
64- async fn do_download ( url : & str , dest : & PathBuf ) -> Result < ( ) , String > {
83+ async fn do_download (
84+ url : & str ,
85+ dest : & PathBuf ,
86+ downloads : & Arc < Mutex < HashMap < String , DownloadProgress > > > ,
87+ id : & str ,
88+ ) -> Result < ( ) , String > {
6589 let client = reqwest:: Client :: builder ( )
66- . user_agent ( "LocallyUncensored/1.3 " )
90+ . user_agent ( "LocallyUncensored/1.5 " )
6791 . redirect ( reqwest:: redirect:: Policy :: limited ( 10 ) )
92+ . timeout ( std:: time:: Duration :: from_secs ( 7200 ) ) // 2 hours for large models
6893 . build ( )
6994 . map_err ( |e| e. to_string ( ) ) ?;
7095
@@ -79,6 +104,14 @@ async fn do_download(url: &str, dest: &PathBuf) -> Result<(), String> {
79104
80105 let total = response. content_length ( ) . unwrap_or ( 0 ) ;
81106
107+ // Update total size
108+ if let Ok ( mut dl) = downloads. lock ( ) {
109+ if let Some ( p) = dl. get_mut ( id) {
110+ p. total = total;
111+ p. status = "downloading" . to_string ( ) ;
112+ }
113+ }
114+
82115 let tmp_path = dest. with_extension ( "download" ) ;
83116 let mut file = tokio:: fs:: File :: create ( & tmp_path)
84117 . await
@@ -87,21 +120,26 @@ async fn do_download(url: &str, dest: &PathBuf) -> Result<(), String> {
87120 let mut stream = response. bytes_stream ( ) ;
88121 let mut downloaded: u64 = 0 ;
89122 let start = Instant :: now ( ) ;
123+ let mut last_update = Instant :: now ( ) ;
90124
91125 use tokio:: io:: AsyncWriteExt ;
92126 while let Some ( chunk) = stream. next ( ) . await {
93127 let chunk = chunk. map_err ( |e| format ! ( "Stream error: {}" , e) ) ?;
94128 file. write_all ( & chunk) . await . map_err ( |e| format ! ( "Write: {}" , e) ) ?;
95129 downloaded += chunk. len ( ) as u64 ;
96130
97- // Log progress every ~1MB
98- if downloaded % ( 1024 * 1024 ) < chunk. len ( ) as u64 {
131+ // Update progress every 500ms
132+ if last_update. elapsed ( ) . as_millis ( ) > 500 {
133+ last_update = Instant :: now ( ) ;
99134 let elapsed = start. elapsed ( ) . as_secs_f64 ( ) ;
100135 let speed = if elapsed > 0.0 { downloaded as f64 / elapsed } else { 0.0 } ;
101- println ! ( "[Download] {:.1} MB / {:.1} MB ({:.1} MB/s)" ,
102- downloaded as f64 / 1048576.0 ,
103- total as f64 / 1048576.0 ,
104- speed / 1048576.0 ) ;
136+
137+ if let Ok ( mut dl) = downloads. lock ( ) {
138+ if let Some ( p) = dl. get_mut ( id) {
139+ p. progress = downloaded;
140+ p. speed = speed;
141+ }
142+ }
105143 }
106144 }
107145
@@ -112,12 +150,21 @@ async fn do_download(url: &str, dest: &PathBuf) -> Result<(), String> {
112150 . await
113151 . map_err ( |e| format ! ( "Rename: {}" , e) ) ?;
114152
153+ // Final progress update
154+ if let Ok ( mut dl) = downloads. lock ( ) {
155+ if let Some ( p) = dl. get_mut ( id) {
156+ p. progress = downloaded;
157+ p. total = downloaded;
158+ p. status = "complete" . to_string ( ) ;
159+ }
160+ }
161+
115162 Ok ( ( ) )
116163}
117164
118165#[ tauri:: command]
119166pub fn download_progress ( state : State < ' _ , AppState > ) -> Result < serde_json:: Value , String > {
120167 let downloads = state. downloads . lock ( ) . unwrap ( ) ;
121- let map: std :: collections :: HashMap < String , DownloadProgress > = downloads. clone ( ) ;
168+ let map: HashMap < String , DownloadProgress > = downloads. clone ( ) ;
122169 Ok ( serde_json:: to_value ( map) . unwrap_or_default ( ) )
123170}
0 commit comments