Skip to content

Commit 0fe9bfe

Browse files
committed
Skip downloading transformer models if they already exist.
Implement a caching mechanism using a JSON file to track the state of downloaded models. This avoids redundant downloads during prebuild if the model configuration has not changed. A --force flag is also added to allow users to manually bypass the cache and re-download all models when necessary.
1 parent 5f8d674 commit 0fe9bfe

File tree

2 files changed

+88
-20
lines changed

2 files changed

+88
-20
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import fs from 'fs';
2+
3+
/**
 * Small persistent set of "already downloaded" keys, stored as a JSON object
 * (`{ "<key>": true, ... }`) in a single file.
 *
 * Every successful `put()` rewrites the whole file synchronously, so the cache
 * on disk always reflects the downloads completed so far.
 */
export default class DownloadCache {
  /** In-memory view of the cache file: key -> true. */
  cached = {};

  /**
   * @param {string} filename - Path of the JSON cache file. The file does not
   *   need to exist yet; it is created on the first `put()`.
   * @param {boolean} [force] - When truthy, any existing cache file is ignored
   *   so every key reports as missing. New entries are still persisted, which
   *   replaces the old file content.
   */
  constructor(filename, force) {
    // Remember the path unconditionally: a forced run still has to be able to
    // write the fresh cache in put().
    this.filename = filename;

    if (force || !fs.existsSync(filename)) {
      return;
    }

    try {
      const parsed = JSON.parse(fs.readFileSync(filename, 'utf8'));
      // Valid JSON is not necessarily an object (e.g. `null`, `[]`, `3`);
      // anything else would break has()/put(), so start from empty instead.
      if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
        this.cached = parsed;
      } else {
        console.warn(`Warning: Could not read cache file ${filename}:`, 'expected a JSON object');
      }
    } catch (err) {
      // Best effort: an unreadable or corrupt cache only costs a re-download.
      console.warn(`Warning: Could not read cache file ${filename}:`, err.message);
    }
  }

  /**
   * @param {string} key - Cache key to look up.
   * @returns {boolean} True only if the key was recorded with a truthy value.
   */
  has(key) {
    // Own-property check so inherited names such as "constructor" or
    // "toString" are not mistaken for cached entries.
    return Object.hasOwn(this.cached, key) && Boolean(this.cached[key]);
  }

  /**
   * Marks `key` as downloaded and writes the whole cache back to disk.
   *
   * @param {string} key - Cache key to record.
   * @throws {Error} Propagates any error from `fs.writeFileSync`.
   */
  put(key) {
    this.cached[key] = true;
    fs.writeFileSync(this.filename, JSON.stringify(this.cached, null, 2));
  }
}

resources/transformers-js/src/download-models.mjs

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js";
33
import fs from 'fs';
44
import path from 'path';
55
import fetch from 'node-fetch';
6+
import DownloadCache from '../../shared/download-cache.mjs';
67

78
const MODEL_DIR = './models';
89
env.localModelPath = MODEL_DIR;
@@ -48,6 +49,30 @@ const KOKORO_FILES = [
4849
'tokenizer_config.json',
4950
];
5051

52+
// Components of the Marqo/marqo-fashionSigLIP repo to fetch, one entry per
// from_pretrained() call made by downloadModels().
//   class - key into MARGO_NAME_TO_CLASS selecting the loader to call.
//   dtype - forwarded unchanged to from_pretrained(); optional.
// Both fields are joined into the download-cache key, so an entry without
// `dtype` ends up with the literal text "undefined" in its key.
// NOTE(review): "MARGO" looks like a misspelling of "Marqo"; a rename has to
// update the usages in downloadModels() as well.
// NOTE(review): the inline calls this list replaced passed no dtype for
// AutoTokenizer - confirm that 'bnb4' on that entry is intended.
const MARGO_MODELS_TO_DOWNLOAD = [
53+
{
54+
class: 'SiglipVisionModel',
55+
dtype: 'bnb4',
56+
},
57+
{
58+
class: 'AutoImageProcessor',
59+
},
60+
{
61+
class: 'SiglipTextModel',
62+
dtype: 'bnb4',
63+
},
64+
{
65+
class: 'AutoTokenizer',
66+
dtype: 'bnb4',
67+
}
68+
];
69+
// Maps the `class` strings used in MARGO_MODELS_TO_DOWNLOAD to the loader
// objects whose from_pretrained() is invoked; every `class` value listed
// there needs an entry here.
const MARGO_NAME_TO_CLASS = {
70+
'SiglipVisionModel': SiglipVisionModel,
71+
'AutoImageProcessor': AutoImageProcessor,
72+
'SiglipTextModel': SiglipTextModel,
73+
'AutoTokenizer': AutoTokenizer,
74+
};
75+
5176
function getHuggingFaceUrl(repo, filename, branch = 'main') {
5277
if(filename.endsWith('.onnx')) {
5378
return `https://huggingface.co/${repo}/resolve/${branch}/onnx/${filename}`;
@@ -56,6 +81,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
5681
}
5782

5883
async function downloadModels() {
84+
const CACHE_FILE = path.join(MODEL_DIR, 'cache.json');
85+
const cache = new DownloadCache(CACHE_FILE, process.argv.includes('--force'));
5986
if (!fs.existsSync(MODEL_DIR)) {
6087
console.log(`Creating directory: ${MODEL_DIR}`);
6188
fs.mkdirSync(MODEL_DIR, { recursive: true });
@@ -71,6 +98,12 @@ async function downloadModels() {
7198
for (const modelInfo of MODELS_TO_DOWNLOAD) {
7299
const { id: modelId, task: modelTask, dtype: modelDType } = modelInfo;
73100

101+
const cacheKey = `${modelId}-${modelTask}-${modelDType}`;
102+
if (cache.has(cacheKey)) {
103+
console.log(`Model ${modelId} (${modelTask}, dtype: ${modelDType}) already cached. Skipping.`);
104+
continue;
105+
}
106+
74107
console.log(`Downloading files for ${modelId} (${modelTask}, dtype: ${modelDType})...`);
75108

76109
await pipeline(
@@ -82,34 +115,41 @@ async function downloadModels() {
82115
});
83116

84117
console.log(`Successfully downloaded and cached ${modelId}`);
118+
cache.put(cacheKey);
85119
}
86120

87121
// Download Marqo/marqo-fashionSigLIP model
88-
console.log(`Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...`);
89-
await SiglipVisionModel.from_pretrained("Marqo/marqo-fashionSigLIP" ,{
90-
cache_dir: env.localModelPath,
91-
dtype: 'bnb4'
92-
});
93-
await AutoImageProcessor.from_pretrained("Marqo/marqo-fashionSigLIP", {
94-
cache_dir: env.localModelPath,
95-
});
96-
await SiglipTextModel.from_pretrained("Marqo/marqo-fashionSigLIP", {
97-
cache_dir: env.localModelPath,
98-
dtype: 'bnb4'
99-
});
100-
await AutoTokenizer.from_pretrained("Marqo/marqo-fashionSigLIP", {
101-
cache_dir: env.localModelPath,
102-
});
103-
console.log(`Successfully downloaded and cached Marqo/marqo-fashionSigLIP`);
122+
console.log(`Checking Marqo/marqo-fashionSigLIP models...`);
123+
for (const modelInfo of MARGO_MODELS_TO_DOWNLOAD) {
124+
const cacheKey = `${modelInfo.class}-${modelInfo.dtype}`;
125+
if (cache.has(cacheKey)) {
126+
console.log(`Model ${modelInfo.class} (dtype: ${modelInfo.dtype}) already cached. Skipping.`);
127+
continue;
128+
}
129+
130+
console.log(`Downloading Marqo/marqo-fashionSigLIP (${modelInfo.class}${modelInfo.dtype ? `, dtype: ${modelInfo.dtype}` : ''})...`);
131+
await MARGO_NAME_TO_CLASS[modelInfo.class].from_pretrained("Marqo/marqo-fashionSigLIP", {
132+
cache_dir: env.localModelPath,
133+
dtype: modelInfo.dtype
134+
});
135+
136+
cache.put(cacheKey);
137+
}
138+
console.log(`Successfully checked Marqo/marqo-fashionSigLIP`);
104139

105140
// Download onnx-community/Kokoro-82M-v1.0-ONNX model
106-
console.log(`Starting manual download for ${KOKORO_REPO}...`);
141+
console.log(`Starting manual download check for ${KOKORO_REPO}...`);
107142
const kokoroModelPath = path.join(MODEL_DIR, KOKORO_REPO);
108143
if (!fs.existsSync(kokoroModelPath)) {
109144
fs.mkdirSync(kokoroModelPath, { recursive: true });
110145
}
111146

112147
for (const filename of KOKORO_FILES) {
148+
const cacheKey = `${KOKORO_REPO}-${filename}`;
149+
if (cache.has(cacheKey)) {
150+
console.log(` ${filename} already exists, skipping.`);
151+
continue;
152+
}
113153
const isOnnxFile = filename.endsWith('.onnx') || filename.endsWith('.onnx_data');
114154
const modelUrl = getHuggingFaceUrl(KOKORO_REPO, filename);
115155
let outputPath;
@@ -136,11 +176,13 @@ async function downloadModels() {
136176
response.body.on('error', reject);
137177
fileStream.on('finish', resolve);
138178
});
179+
180+
cache.put(cacheKey);
139181
} catch (err) {
140182
console.error(` Failed to download ${filename}:`, err.message);
141183
}
142184
}
143-
console.log(`Successfully downloaded all files for ${KOKORO_REPO}`);
185+
console.log(`Successfully checked all files for ${KOKORO_REPO}`);
144186

145187
} catch (err) {
146188
console.error("Model download failed:", err);
@@ -151,6 +193,6 @@ async function downloadModels() {
151193
}
152194

153195
downloadModels().catch(err => {
154-
console.error("Download process terminated.");
196+
console.error("Download process terminated.", err);
155197
process.exit(1);
156-
});
198+
});

0 commit comments

Comments
 (0)