Skip to content

Commit 571fba8

Browse files
authored
Skip downloading transformer models if they already exist. (#64)
Implement a caching mechanism using a JSON file to track the state of downloaded models. This avoids redundant downloads during prebuild if the model configuration has not changed. A --force flag is also added to allow users to manually bypass the cache and re-download all models when necessary.
1 parent 86777b1 commit 571fba8

File tree

2 files changed

+87
-20
lines changed

2 files changed

+87
-20
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import fs from 'fs';
2+
3+
export default class DownloadCache {
4+
cached = {};
5+
6+
constructor(filename, force) {
7+
if (force) {
8+
return;
9+
}
10+
this.filename = filename;
11+
if (fs.existsSync(filename)) {
12+
try {
13+
this.cached = JSON.parse(fs.readFileSync(filename, 'utf8'));
14+
} catch (err) {
15+
console.warn(`Warning: Could not read cache file ${filename}:`, err.message);
16+
}
17+
}
18+
}
19+
has(key) {
20+
return !!this.cached[key];
21+
}
22+
put(key) {
23+
this.cached[key] = true;
24+
fs.writeFileSync(this.filename, JSON.stringify(this.cached, null, 2));
25+
}
26+
}

resources/transformers-js/src/download-models.mjs

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js";
33
import fs from 'fs';
44
import path from 'path';
55
import fetch from 'node-fetch';
6+
import DownloadCache from '../../shared/download-cache.mjs';
67

78
const MODEL_DIR = './models';
89
env.localModelPath = MODEL_DIR;
@@ -48,6 +49,29 @@ const KOKORO_FILES = [
4849
'tokenizer_config.json',
4950
];
5051

52+
const MARGO_MODELS_TO_DOWNLOAD = [
53+
{
54+
class: 'SiglipVisionModel',
55+
dtype: 'bnb4',
56+
},
57+
{
58+
class: 'AutoImageProcessor',
59+
},
60+
{
61+
class: 'SiglipTextModel',
62+
dtype: 'bnb4',
63+
},
64+
{
65+
class: 'AutoTokenizer',
66+
}
67+
];
68+
const MARGO_NAME_TO_CLASS = {
69+
'SiglipVisionModel': SiglipVisionModel,
70+
'AutoImageProcessor': AutoImageProcessor,
71+
'SiglipTextModel': SiglipTextModel,
72+
'AutoTokenizer': AutoTokenizer,
73+
};
74+
5175
function getHuggingFaceUrl(repo, filename, branch = 'main') {
5276
if(filename.endsWith('.onnx')) {
5377
return `https://huggingface.co/${repo}/resolve/${branch}/onnx/${filename}`;
@@ -56,6 +80,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
5680
}
5781

5882
async function downloadModels() {
83+
const CACHE_FILE = path.join(MODEL_DIR, 'cache.json');
84+
const cache = new DownloadCache(CACHE_FILE, process.argv.includes('--force'));
5985
if (!fs.existsSync(MODEL_DIR)) {
6086
console.log(`Creating directory: ${MODEL_DIR}`);
6187
fs.mkdirSync(MODEL_DIR, { recursive: true });
@@ -71,6 +97,12 @@ async function downloadModels() {
7197
for (const modelInfo of MODELS_TO_DOWNLOAD) {
7298
const { id: modelId, task: modelTask, dtype: modelDType } = modelInfo;
7399

100+
const cacheKey = `${modelId}-${modelTask}-${modelDType}`;
101+
if (cache.has(cacheKey)) {
102+
console.log(`Model ${modelId} (${modelTask}, dtype: ${modelDType}) already cached. Skipping.`);
103+
continue;
104+
}
105+
74106
console.log(`Downloading files for ${modelId} (${modelTask}, dtype: ${modelDType})...`);
75107

76108
await pipeline(
@@ -82,34 +114,41 @@ async function downloadModels() {
82114
});
83115

84116
console.log(`Successfully downloaded and cached ${modelId}`);
117+
cache.put(cacheKey);
85118
}
86119

87120
// Download Marqo/marqo-fashionSigLIP model
88-
console.log(`Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...`);
89-
await SiglipVisionModel.from_pretrained("Marqo/marqo-fashionSigLIP" ,{
90-
cache_dir: env.localModelPath,
91-
dtype: 'bnb4'
92-
});
93-
await AutoImageProcessor.from_pretrained("Marqo/marqo-fashionSigLIP", {
94-
cache_dir: env.localModelPath,
95-
});
96-
await SiglipTextModel.from_pretrained("Marqo/marqo-fashionSigLIP", {
97-
cache_dir: env.localModelPath,
98-
dtype: 'bnb4'
99-
});
100-
await AutoTokenizer.from_pretrained("Marqo/marqo-fashionSigLIP", {
101-
cache_dir: env.localModelPath,
102-
});
103-
console.log(`Successfully downloaded and cached Marqo/marqo-fashionSigLIP`);
121+
console.log(`Checking Marqo/marqo-fashionSigLIP models...`);
122+
for (const modelInfo of MARGO_MODELS_TO_DOWNLOAD) {
123+
const cacheKey = `${modelInfo.class}-${modelInfo.dtype}`;
124+
if (cache.has(cacheKey)) {
125+
console.log(`Model ${modelInfo.class} (dtype: ${modelInfo.dtype}) already cached. Skipping.`);
126+
continue;
127+
}
128+
129+
console.log(`Downloading Marqo/marqo-fashionSigLIP (${modelInfo.class}${modelInfo.dtype ? `, dtype: ${modelInfo.dtype}` : ''})...`);
130+
await MARGO_NAME_TO_CLASS[modelInfo.class].from_pretrained("Marqo/marqo-fashionSigLIP", {
131+
cache_dir: env.localModelPath,
132+
dtype: modelInfo.dtype
133+
});
134+
135+
cache.put(cacheKey);
136+
}
137+
console.log(`Successfully checked Marqo/marqo-fashionSigLIP`);
104138

105139
// Download onnx-community/Kokoro-82M-v1.0-ONNX model
106-
console.log(`Starting manual download for ${KOKORO_REPO}...`);
140+
console.log(`Starting manual download check for ${KOKORO_REPO}...`);
107141
const kokoroModelPath = path.join(MODEL_DIR, KOKORO_REPO);
108142
if (!fs.existsSync(kokoroModelPath)) {
109143
fs.mkdirSync(kokoroModelPath, { recursive: true });
110144
}
111145

112146
for (const filename of KOKORO_FILES) {
147+
const cacheKey = `${KOKORO_REPO}-${filename}`;
148+
if (cache.has(cacheKey)) {
149+
console.log(` ${filename} already exists, skipping.`);
150+
continue;
151+
}
113152
const isOnnxFile = filename.endsWith('.onnx') || filename.endsWith('.onnx_data');
114153
const modelUrl = getHuggingFaceUrl(KOKORO_REPO, filename);
115154
let outputPath;
@@ -136,11 +175,13 @@ async function downloadModels() {
136175
response.body.on('error', reject);
137176
fileStream.on('finish', resolve);
138177
});
178+
179+
cache.put(cacheKey);
139180
} catch (err) {
140181
console.error(` Failed to download ${filename}:`, err.message);
141182
}
142183
}
143-
console.log(`Successfully downloaded all files for ${KOKORO_REPO}`);
184+
console.log(`Successfully checked all files for ${KOKORO_REPO}`);
144185

145186
} catch (err) {
146187
console.error("Model download failed:", err);
@@ -151,6 +192,6 @@ async function downloadModels() {
151192
}
152193

153194
downloadModels().catch(err => {
154-
console.error("Download process terminated.");
195+
console.error("Download process terminated.", err);
155196
process.exit(1);
156-
});
197+
});

0 commit comments

Comments
 (0)