From 0fe9bfe32ac83db49e255e1edb4e5c61db8148ee Mon Sep 17 00:00:00 2001 From: Brendan Dahl Date: Thu, 22 Jan 2026 22:28:22 +0000 Subject: [PATCH] Skip downloading transformer models if they already exist. Implement a caching mechanism using a JSON file to track the state of downloaded models. This avoids redundant downloads during prebuild if the model configuration has not changed. A --force flag is also added to allow users to manually bypass the cache and re-download all models when necessary. --- resources/shared/download-cache.mjs | 26 ++++++ .../transformers-js/src/download-models.mjs | 82 ++++++++++++++----- 2 files changed, 88 insertions(+), 20 deletions(-) create mode 100644 resources/shared/download-cache.mjs diff --git a/resources/shared/download-cache.mjs b/resources/shared/download-cache.mjs new file mode 100644 index 0000000..dc4f070 --- /dev/null +++ b/resources/shared/download-cache.mjs @@ -0,0 +1,26 @@ +import fs from 'fs'; + +export default class DownloadCache { + cached = {}; + + constructor(filename, force) { + if (force) { + return; + } + this.filename = filename; + if (fs.existsSync(filename)) { + try { + this.cached = JSON.parse(fs.readFileSync(filename, 'utf8')); + } catch (err) { + console.warn(`Warning: Could not read cache file ${filename}:`, err.message); + } + } + } + has(key) { + return !!this.cached[key]; + } + put(key) { + this.cached[key] = true; + fs.writeFileSync(this.filename, JSON.stringify(this.cached, null, 2)); + } +} diff --git a/resources/transformers-js/src/download-models.mjs b/resources/transformers-js/src/download-models.mjs index bbb6faa..f55a14f 100644 --- a/resources/transformers-js/src/download-models.mjs +++ b/resources/transformers-js/src/download-models.mjs @@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js"; import fs from 'fs'; import path from 'path'; import fetch from 'node-fetch'; +import DownloadCache from '../../shared/download-cache.mjs'; const MODEL_DIR = './models'; env.localModelPath = MODEL_DIR; @@ -48,6 +49,30 @@ const KOKORO_FILES = [ 'tokenizer_config.json', ]; +const MARGO_MODELS_TO_DOWNLOAD = [ + { + class: 'SiglipVisionModel', + dtype: 'bnb4', + }, + { + class: 'AutoImageProcessor', + }, + { + class: 'SiglipTextModel', + dtype: 'bnb4', + }, + { + class: 'AutoTokenizer', + dtype: 'bnb4', + } +]; +const MARGO_NAME_TO_CLASS = { + 'SiglipVisionModel': SiglipVisionModel, + 'AutoImageProcessor': AutoImageProcessor, + 'SiglipTextModel': SiglipTextModel, + 'AutoTokenizer': AutoTokenizer, +}; + function getHuggingFaceUrl(repo, filename, branch = 'main') { if(filename.endsWith('.onnx')) { return `https://huggingface.co/${repo}/resolve/${branch}/onnx/${filename}`; @@ -56,6 +81,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') { } async function downloadModels() { + const CACHE_FILE = path.join(MODEL_DIR, 'cache.json'); + const cache = new DownloadCache(CACHE_FILE, process.argv.includes('--force')); if (!fs.existsSync(MODEL_DIR)) { console.log(`Creating directory: ${MODEL_DIR}`); fs.mkdirSync(MODEL_DIR, { recursive: true }); @@ -71,6 +98,12 @@ async function downloadModels() { for (const modelInfo of MODELS_TO_DOWNLOAD) { const { id: modelId, task: modelTask, dtype: modelDType } = modelInfo; + const cacheKey = `${modelId}-${modelTask}-${modelDType}`; + if (cache.has(cacheKey)) { + console.log(`Model ${modelId} (${modelTask}, dtype: ${modelDType}) already cached. Skipping.`); + continue; + } + console.log(`Downloading files for ${modelId} (${modelTask}, dtype: ${modelDType})...`); await pipeline( @@ -82,34 +115,41 @@ async function downloadModels() { }); console.log(`Successfully downloaded and cached ${modelId}`); + cache.put(cacheKey); } // Download Marqo/marqo-fashionSigLIP model - console.log(`Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...`); - await SiglipVisionModel.from_pretrained("Marqo/marqo-fashionSigLIP" ,{ - cache_dir: env.localModelPath, - dtype: 'bnb4' - }); - await AutoImageProcessor.from_pretrained("Marqo/marqo-fashionSigLIP", { - cache_dir: env.localModelPath, - }); - await SiglipTextModel.from_pretrained("Marqo/marqo-fashionSigLIP", { - cache_dir: env.localModelPath, - dtype: 'bnb4' - }); - await AutoTokenizer.from_pretrained("Marqo/marqo-fashionSigLIP", { - cache_dir: env.localModelPath, - }); - console.log(`Successfully downloaded and cached Marqo/marqo-fashionSigLIP`); + console.log(`Checking Marqo/marqo-fashionSigLIP models...`); + for (const modelInfo of MARGO_MODELS_TO_DOWNLOAD) { + const cacheKey = `${modelInfo.class}-${modelInfo.dtype}`; + if (cache.has(cacheKey)) { + console.log(`Model ${modelInfo.class} (dtype: ${modelInfo.dtype}) already cached. Skipping.`); + continue; + } + + console.log(`Downloading Marqo/marqo-fashionSigLIP (${modelInfo.class}${modelInfo.dtype ? `, dtype: ${modelInfo.dtype}` : ''})...`); + await MARGO_NAME_TO_CLASS[modelInfo.class].from_pretrained("Marqo/marqo-fashionSigLIP", { + cache_dir: env.localModelPath, + dtype: modelInfo.dtype + }); + + cache.put(cacheKey); + } + console.log(`Successfully checked Marqo/marqo-fashionSigLIP`); // Download onnx-community/Kokoro-82M-v1.0-ONNX model - console.log(`Starting manual download for ${KOKORO_REPO}...`); + console.log(`Starting manual download check for ${KOKORO_REPO}...`); const kokoroModelPath = path.join(MODEL_DIR, KOKORO_REPO); if (!fs.existsSync(kokoroModelPath)) { fs.mkdirSync(kokoroModelPath, { recursive: true }); } for (const filename of KOKORO_FILES) { + const cacheKey = `${KOKORO_REPO}-${filename}`; + if (cache.has(cacheKey)) { + console.log(` ${filename} already exists, skipping.`); + continue; + } const isOnnxFile = filename.endsWith('.onnx') || filename.endsWith('.onnx_data'); const modelUrl = getHuggingFaceUrl(KOKORO_REPO, filename); let outputPath; @@ -136,11 +176,13 @@ async function downloadModels() { response.body.on('error', reject); fileStream.on('finish', resolve); }); + + cache.put(cacheKey); } catch (err) { console.error(` Failed to download ${filename}:`, err.message); } } - console.log(`Successfully downloaded all files for ${KOKORO_REPO}`); + console.log(`Successfully checked all files for ${KOKORO_REPO}`); } catch (err) { console.error("Model download failed:", err); @@ -151,6 +193,6 @@ async function downloadModels() { } downloadModels().catch(err => { - console.error("Download process terminated."); + console.error("Download process terminated.", err); process.exit(1); -}); \ No newline at end of file +});