@@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js";
33import fs from 'fs' ;
44import path from 'path' ;
55import fetch from 'node-fetch' ;
6+ import DownloadCache from '../../shared/download-cache.mjs' ;
67
78const MODEL_DIR = './models' ;
89env . localModelPath = MODEL_DIR ;
@@ -48,6 +49,30 @@ const KOKORO_FILES = [
4849 'tokenizer_config.json' ,
4950] ;
5051
/**
 * Components of the Marqo/marqo-fashionSigLIP checkpoint that downloadModels()
 * fetches one by one.
 *
 * - `class` is the key looked up in MARGO_NAME_TO_CLASS to find the loader.
 * - `dtype` (optional) is forwarded as-is to that loader's `from_pretrained`.
 *
 * Only the two model classes take a dtype; the image processor and the
 * tokenizer are loaded without one, so their entries deliberately omit it.
 * `class` and `dtype` together also form the download-cache key.
 */
const MARGO_MODELS_TO_DOWNLOAD = [
  {
    class: 'SiglipVisionModel',
    dtype: 'bnb4',
  },
  {
    class: 'AutoImageProcessor',
  },
  {
    class: 'SiglipTextModel',
    dtype: 'bnb4',
  },
  {
    // No dtype here: a tokenizer is not a weight-bearing model.
    class: 'AutoTokenizer',
  },
];
// Resolves the `class` string of a MARGO_MODELS_TO_DOWNLOAD entry to the
// imported loader whose `from_pretrained` is invoked for that entry.
// Each key is identical to the identifier it maps to, hence the shorthand.
const MARGO_NAME_TO_CLASS = {
  SiglipVisionModel,
  AutoImageProcessor,
  SiglipTextModel,
  AutoTokenizer,
};
75+
5176function getHuggingFaceUrl ( repo , filename , branch = 'main' ) {
5277 if ( filename . endsWith ( '.onnx' ) ) {
5378 return `https://huggingface.co/${ repo } /resolve/${ branch } /onnx/${ filename } ` ;
@@ -56,6 +81,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
5681}
5782
5883async function downloadModels ( ) {
84+ const CACHE_FILE = path . join ( MODEL_DIR , 'cache.json' ) ;
85+ const cache = new DownloadCache ( CACHE_FILE , process . argv . includes ( '--force' ) ) ;
5986 if ( ! fs . existsSync ( MODEL_DIR ) ) {
6087 console . log ( `Creating directory: ${ MODEL_DIR } ` ) ;
6188 fs . mkdirSync ( MODEL_DIR , { recursive : true } ) ;
@@ -71,6 +98,12 @@ async function downloadModels() {
7198 for ( const modelInfo of MODELS_TO_DOWNLOAD ) {
7299 const { id : modelId , task : modelTask , dtype : modelDType } = modelInfo ;
73100
101+ const cacheKey = `${ modelId } -${ modelTask } -${ modelDType } ` ;
102+ if ( cache . has ( cacheKey ) ) {
103+ console . log ( `Model ${ modelId } (${ modelTask } , dtype: ${ modelDType } ) already cached. Skipping.` ) ;
104+ continue ;
105+ }
106+
74107 console . log ( `Downloading files for ${ modelId } (${ modelTask } , dtype: ${ modelDType } )...` ) ;
75108
76109 await pipeline (
@@ -82,34 +115,41 @@ async function downloadModels() {
82115 } ) ;
83116
84117 console . log ( `Successfully downloaded and cached ${ modelId } ` ) ;
118+ cache . put ( cacheKey ) ;
85119 }
86120
87121 // Download Marqo/marqo-fashionSigLIP model
88- console . log ( `Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...` ) ;
89- await SiglipVisionModel . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
90- cache_dir : env . localModelPath ,
91- dtype : 'bnb4'
92- } ) ;
93- await AutoImageProcessor . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
94- cache_dir : env . localModelPath ,
95- } ) ;
96- await SiglipTextModel . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
97- cache_dir : env . localModelPath ,
98- dtype : 'bnb4'
99- } ) ;
100- await AutoTokenizer . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
101- cache_dir : env . localModelPath ,
102- } ) ;
103- console . log ( `Successfully downloaded and cached Marqo/marqo-fashionSigLIP` ) ;
122+ console . log ( `Checking Marqo/marqo-fashionSigLIP models...` ) ;
123+ for ( const modelInfo of MARGO_MODELS_TO_DOWNLOAD ) {
124+ const cacheKey = `${ modelInfo . class } -${ modelInfo . dtype } ` ;
125+ if ( cache . has ( cacheKey ) ) {
126+ console . log ( `Model ${ modelInfo . class } (dtype: ${ modelInfo . dtype } ) already cached. Skipping.` ) ;
127+ continue ;
128+ }
129+
130+ console . log ( `Downloading Marqo/marqo-fashionSigLIP (${ modelInfo . class } ${ modelInfo . dtype ? `, dtype: ${ modelInfo . dtype } ` : '' } )...` ) ;
131+ await MARGO_NAME_TO_CLASS [ modelInfo . class ] . from_pretrained ( "Marqo/marqo-fashionSigLIP" , {
132+ cache_dir : env . localModelPath ,
133+ dtype : modelInfo . dtype
134+ } ) ;
135+
136+ cache . put ( cacheKey ) ;
137+ }
138+ console . log ( `Successfully checked Marqo/marqo-fashionSigLIP` ) ;
104139
105140 // Download onnx-community/Kokoro-82M-v1.0-ONNX model
106- console . log ( `Starting manual download for ${ KOKORO_REPO } ...` ) ;
141+ console . log ( `Starting manual download check for ${ KOKORO_REPO } ...` ) ;
107142 const kokoroModelPath = path . join ( MODEL_DIR , KOKORO_REPO ) ;
108143 if ( ! fs . existsSync ( kokoroModelPath ) ) {
109144 fs . mkdirSync ( kokoroModelPath , { recursive : true } ) ;
110145 }
111146
112147 for ( const filename of KOKORO_FILES ) {
148+ const cacheKey = `${ KOKORO_REPO } -${ filename } ` ;
149+ if ( cache . has ( cacheKey ) ) {
150+ console . log ( ` ${ filename } already exists, skipping.` ) ;
151+ continue ;
152+ }
113153 const isOnnxFile = filename . endsWith ( '.onnx' ) || filename . endsWith ( '.onnx_data' ) ;
114154 const modelUrl = getHuggingFaceUrl ( KOKORO_REPO , filename ) ;
115155 let outputPath ;
@@ -136,11 +176,13 @@ async function downloadModels() {
136176 response . body . on ( 'error' , reject ) ;
137177 fileStream . on ( 'finish' , resolve ) ;
138178 } ) ;
179+
180+ cache . put ( cacheKey ) ;
139181 } catch ( err ) {
140182 console . error ( ` Failed to download ${ filename } :` , err . message ) ;
141183 }
142184 }
143- console . log ( `Successfully downloaded all files for ${ KOKORO_REPO } ` ) ;
185+ console . log ( `Successfully checked all files for ${ KOKORO_REPO } ` ) ;
144186
145187 } catch ( err ) {
146188 console . error ( "Model download failed:" , err ) ;
@@ -151,6 +193,6 @@ async function downloadModels() {
151193}
152194
// Script entry point: start the download run. If the returned promise
// rejects, log the failure together with the error and exit with status 1.
downloadModels().catch((failure) => {
  console.error("Download process terminated.", failure);
  process.exit(1);
});
0 commit comments