Skip to content

Commit 0d4fab1

Browse files
committed
Skip downloading transformer models if they already exist.
Implement a caching mechanism using a JSON file to track the state of downloaded models. This avoids redundant downloads during prebuild if the model configuration has not changed. A --force flag is also added to allow users to manually bypass the cache and re-download all models when necessary.
1 parent 5f8d674 commit 0d4fab1

File tree

1 file changed

+56
-14
lines changed

1 file changed

+56
-14
lines changed

resources/transformers-js/src/download-models.mjs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,30 @@ const KOKORO_FILES = [
4848
'tokenizer_config.json',
4949
];
5050

51+
const MARGO_MODELS_TO_DOWNLOAD = [
52+
{
53+
class: 'SiglipVisionModel',
54+
dtype: 'bnb4',
55+
},
56+
{
57+
class: 'AutoImageProcessor',
58+
},
59+
{
60+
class: 'SiglipTextModel',
61+
dtype: 'bnb4',
62+
},
63+
{
64+
class: 'AutoTokenizer',
65+
dtype: 'bnb4',
66+
}
67+
];
68+
const MARGO_NAME_TO_CLASS = {
69+
'SiglipVisionModel': SiglipVisionModel,
70+
'AutoImageProcessor': AutoImageProcessor,
71+
'SiglipTextModel': SiglipTextModel,
72+
'AutoTokenizer': AutoTokenizer,
73+
};
74+
5175
function getHuggingFaceUrl(repo, filename, branch = 'main') {
5276
if(filename.endsWith('.onnx')) {
5377
return `https://huggingface.co/${repo}/resolve/${branch}/onnx/${filename}`;
@@ -56,6 +80,28 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
5680
}
5781

5882
async function downloadModels() {
83+
const CACHE_FILE = path.join(MODEL_DIR, 'cache.json');
84+
const currentConfig = {
85+
MODELS_TO_DOWNLOAD,
86+
KOKORO_REPO,
87+
KOKORO_FILES,
88+
MARGO_MODELS_TO_DOWNLOAD,
89+
};
90+
91+
const force = process.argv.includes('--force');
92+
93+
if (!force && fs.existsSync(CACHE_FILE)) {
94+
try {
95+
const cachedConfig = JSON.parse(fs.readFileSync(CACHE_FILE, 'utf8'));
96+
if (JSON.stringify(cachedConfig) === JSON.stringify(currentConfig)) {
97+
console.log('Models are already up to date. Skipping download. Use --force to override.');
98+
return;
99+
}
100+
} catch (err) {
101+
console.warn(`Warning: Could not read cache file ${CACHE_FILE}:`, err.message);
102+
}
103+
}
104+
59105
if (!fs.existsSync(MODEL_DIR)) {
60106
console.log(`Creating directory: ${MODEL_DIR}`);
61107
fs.mkdirSync(MODEL_DIR, { recursive: true });
@@ -86,20 +132,12 @@ async function downloadModels() {
86132

87133
// Download Marqo/marqo-fashionSigLIP model
88134
console.log(`Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...`);
89-
await SiglipVisionModel.from_pretrained("Marqo/marqo-fashionSigLIP" ,{
90-
cache_dir: env.localModelPath,
91-
dtype: 'bnb4'
92-
});
93-
await AutoImageProcessor.from_pretrained("Marqo/marqo-fashionSigLIP", {
94-
cache_dir: env.localModelPath,
95-
});
96-
await SiglipTextModel.from_pretrained("Marqo/marqo-fashionSigLIP", {
97-
cache_dir: env.localModelPath,
98-
dtype: 'bnb4'
99-
});
100-
await AutoTokenizer.from_pretrained("Marqo/marqo-fashionSigLIP", {
101-
cache_dir: env.localModelPath,
102-
});
135+
for (const modelInfo of MARGO_MODELS_TO_DOWNLOAD) {
136+
await MARGO_NAME_TO_CLASS[modelInfo.class].from_pretrained("Marqo/marqo-fashionSigLIP", {
137+
cache_dir: env.localModelPath,
138+
dtype: modelInfo.dtype
139+
});
140+
}
103141
console.log(`Successfully downloaded and cached Marqo/marqo-fashionSigLIP`);
104142

105143
// Download onnx-community/Kokoro-82M-v1.0-ONNX model
@@ -142,6 +180,10 @@ async function downloadModels() {
142180
}
143181
console.log(`Successfully downloaded all files for ${KOKORO_REPO}`);
144182

183+
// Update cache file
184+
fs.writeFileSync(CACHE_FILE, JSON.stringify(currentConfig, null, 2));
185+
console.log(`Updated cache file: ${CACHE_FILE}\n`);
186+
145187
} catch (err) {
146188
console.error("Model download failed:", err);
147189
env.allowRemoteModels = originalAllowRemote;

0 commit comments

Comments
 (0)