@@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js";
33import fs from 'fs' ;
44import path from 'path' ;
55import fetch from 'node-fetch' ;
6+ import DownloadCache from '../../shared/download-cache.mjs' ;
67
78const MODEL_DIR = './models' ;
89env . localModelPath = MODEL_DIR ;
@@ -48,6 +49,29 @@ const KOKORO_FILES = [
4849 'tokenizer_config.json' ,
4950] ;
5051
52+ const MARGO_MODELS_TO_DOWNLOAD = [
53+ {
54+ class : 'SiglipVisionModel' ,
55+ dtype : 'bnb4' ,
56+ } ,
57+ {
58+ class : 'AutoImageProcessor' ,
59+ } ,
60+ {
61+ class : 'SiglipTextModel' ,
62+ dtype : 'bnb4' ,
63+ } ,
64+ {
65+ class : 'AutoTokenizer' ,
66+ }
67+ ] ;
68+ const MARGO_NAME_TO_CLASS = {
69+ 'SiglipVisionModel' : SiglipVisionModel ,
70+ 'AutoImageProcessor' : AutoImageProcessor ,
71+ 'SiglipTextModel' : SiglipTextModel ,
72+ 'AutoTokenizer' : AutoTokenizer ,
73+ } ;
74+
5175function getHuggingFaceUrl ( repo , filename , branch = 'main' ) {
5276 if ( filename . endsWith ( '.onnx' ) ) {
5377 return `https://huggingface.co/${ repo } /resolve/${ branch } /onnx/${ filename } ` ;
@@ -56,6 +80,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
5680}
5781
5882async function downloadModels ( ) {
83+ const CACHE_FILE = path . join ( MODEL_DIR , 'cache.json' ) ;
84+ const cache = new DownloadCache ( CACHE_FILE , process . argv . includes ( '--force' ) ) ;
5985 if ( ! fs . existsSync ( MODEL_DIR ) ) {
6086 console . log ( `Creating directory: ${ MODEL_DIR } ` ) ;
6187 fs . mkdirSync ( MODEL_DIR , { recursive : true } ) ;
@@ -71,6 +97,12 @@ async function downloadModels() {
7197 for ( const modelInfo of MODELS_TO_DOWNLOAD ) {
7298 const { id : modelId , task : modelTask , dtype : modelDType } = modelInfo ;
7399
100+ const cacheKey = `${ modelId } -${ modelTask } -${ modelDType } ` ;
101+ if ( cache . has ( cacheKey ) ) {
102+ console . log ( `Model ${ modelId } (${ modelTask } , dtype: ${ modelDType } ) already cached. Skipping.` ) ;
103+ continue ;
104+ }
105+
74106 console . log ( `Downloading files for ${ modelId } (${ modelTask } , dtype: ${ modelDType } )...` ) ;
75107
76108 await pipeline (
@@ -82,34 +114,41 @@ async function downloadModels() {
82114 } ) ;
83115
84116 console . log ( `Successfully downloaded and cached ${ modelId } ` ) ;
117+ cache . put ( cacheKey ) ;
85118 }
86119
87120 // Download Marqo/marqo-fashionSigLIP model
88- console . log ( `Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...` ) ;
89- await SiglipVisionModel . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
90- cache_dir : env . localModelPath ,
91- dtype : 'bnb4'
92- } ) ;
93- await AutoImageProcessor . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
94- cache_dir : env . localModelPath ,
95- } ) ;
96- await SiglipTextModel . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
97- cache_dir : env . localModelPath ,
98- dtype : 'bnb4'
99- } ) ;
100- await AutoTokenizer . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
101- cache_dir : env . localModelPath ,
102- } ) ;
103- console . log ( `Successfully downloaded and cached Marqo/marqo-fashionSigLIP` ) ;
121+ console . log ( `Checking Marqo/marqo-fashionSigLIP models...` ) ;
122+ for ( const modelInfo of MARGO_MODELS_TO_DOWNLOAD ) {
123+ const cacheKey = `${ modelInfo . class } -${ modelInfo . dtype } ` ;
124+ if ( cache . has ( cacheKey ) ) {
125+ console . log ( `Model ${ modelInfo . class } (dtype: ${ modelInfo . dtype } ) already cached. Skipping.` ) ;
126+ continue ;
127+ }
128+
129+ console . log ( `Downloading Marqo/marqo-fashionSigLIP (${ modelInfo . class } ${ modelInfo . dtype ? `, dtype: ${ modelInfo . dtype } ` : '' } )...` ) ;
130+ await MARGO_NAME_TO_CLASS [ modelInfo . class ] . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
131+ cache_dir : env . localModelPath ,
132+ dtype : modelInfo . dtype
133+ } ) ;
134+
135+ cache . put ( cacheKey ) ;
136+ }
137+ console . log ( `Successfully checked Marqo/marqo-fashionSigLIP` ) ;
104138
105139 // Download onnx-community/Kokoro-82M-v1.0-ONNX model
106- console . log ( `Starting manual download for ${ KOKORO_REPO } ...` ) ;
140+ console . log ( `Starting manual download check for ${ KOKORO_REPO } ...` ) ;
107141 const kokoroModelPath = path . join ( MODEL_DIR , KOKORO_REPO ) ;
108142 if ( ! fs . existsSync ( kokoroModelPath ) ) {
109143 fs . mkdirSync ( kokoroModelPath , { recursive : true } ) ;
110144 }
111145
112146 for ( const filename of KOKORO_FILES ) {
147+ const cacheKey = `${ KOKORO_REPO } -${ filename } ` ;
148+ if ( cache . has ( cacheKey ) ) {
149+ console . log ( ` ${ filename } already exists, skipping.` ) ;
150+ continue ;
151+ }
113152 const isOnnxFile = filename . endsWith ( '.onnx' ) || filename . endsWith ( '.onnx_data' ) ;
114153 const modelUrl = getHuggingFaceUrl ( KOKORO_REPO , filename ) ;
115154 let outputPath ;
@@ -136,11 +175,13 @@ async function downloadModels() {
136175 response . body . on ( 'error' , reject ) ;
137176 fileStream . on ( 'finish' , resolve ) ;
138177 } ) ;
178+
179+ cache . put ( cacheKey ) ;
139180 } catch ( err ) {
140181 console . error ( ` Failed to download ${ filename } :` , err . message ) ;
141182 }
142183 }
143- console . log ( `Successfully downloaded all files for ${ KOKORO_REPO } ` ) ;
184+ console . log ( `Successfully checked all files for ${ KOKORO_REPO } ` ) ;
144185
145186 } catch ( err ) {
146187 console . error ( "Model download failed:" , err ) ;
@@ -151,6 +192,6 @@ async function downloadModels() {
151192}
152193
153194downloadModels ( ) . catch ( err => {
154- console . error ( "Download process terminated." ) ;
195+ console . error ( "Download process terminated." , err ) ;
155196 process . exit ( 1 ) ;
156- } ) ;
197+ } ) ;
0 commit comments