Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions resources/shared/download-cache.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import fs from 'fs';

/**
 * Small persistent record of completed downloads.
 *
 * Entries are kept in memory as an object mapping cache keys to `true`, and the
 * whole object is rewritten to a JSON file every time a key is added.
 */
export default class DownloadCache {
  cached = {};

  /**
   * @param {string} filename - Path of the JSON file backing the cache.
   * @param {boolean} [force] - When truthy, any existing cache file is ignored
   *   and the cache starts empty; entries added afterwards are still persisted.
   */
  constructor(filename, force) {
    // Store the path before the force check: put() needs it even when the
    // existing cache contents are deliberately ignored.
    this.filename = filename;
    if (force) {
      return;
    }
    if (fs.existsSync(filename)) {
      try {
        const parsed = JSON.parse(fs.readFileSync(filename, 'utf8'));
        // Valid JSON that is not a plain object (null, array, number, ...)
        // cannot serve as the key map; treat it like an unreadable file.
        if (parsed === null || typeof parsed !== 'object' || Array.isArray(parsed)) {
          throw new Error('expected a JSON object');
        }
        this.cached = parsed;
      } catch (err) {
        console.warn(`Warning: Could not read cache file ${filename}:`, err.message);
      }
    }
  }

  /**
   * @param {string} key - Cache key to look up.
   * @returns {boolean} True when the key has been recorded with a truthy value.
   */
  has(key) {
    // Own-property check so inherited names such as "toString" are not
    // reported as cached.
    return Object.prototype.hasOwnProperty.call(this.cached, key) && Boolean(this.cached[key]);
  }

  /**
   * Marks a key as cached and synchronously rewrites the cache file.
   *
   * @param {string} key - Cache key to record.
   */
  put(key) {
    this.cached[key] = true;
    fs.writeFileSync(this.filename, JSON.stringify(this.cached, null, 2));
  }
}
82 changes: 62 additions & 20 deletions resources/transformers-js/src/download-models.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { KokoroTTS } from "kokoro-js";
import fs from 'fs';
import path from 'path';
import fetch from 'node-fetch';
import DownloadCache from '../../shared/download-cache.mjs';

const MODEL_DIR = './models';
env.localModelPath = MODEL_DIR;
Expand Down Expand Up @@ -48,6 +49,30 @@ const KOKORO_FILES = [
'tokenizer_config.json',
];

const MARGO_MODELS_TO_DOWNLOAD = [
{
class: 'SiglipVisionModel',
dtype: 'bnb4',
},
{
class: 'AutoImageProcessor',
},
{
class: 'SiglipTextModel',
dtype: 'bnb4',
},
{
class: 'AutoTokenizer',
dtype: 'bnb4',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I don't think AutoTokenizer has dtype.

}
];
// Maps the `class` string of a MARGO_MODELS_TO_DOWNLOAD entry to the identifier
// of the same name in scope, so the download loop can call from_pretrained() on it.
const MARGO_NAME_TO_CLASS = {
  SiglipVisionModel,
  AutoImageProcessor,
  SiglipTextModel,
  AutoTokenizer,
};

function getHuggingFaceUrl(repo, filename, branch = 'main') {
if(filename.endsWith('.onnx')) {
return `https://huggingface.co/${repo}/resolve/${branch}/onnx/${filename}`;
Expand All @@ -56,6 +81,8 @@ function getHuggingFaceUrl(repo, filename, branch = 'main') {
}

async function downloadModels() {
const CACHE_FILE = path.join(MODEL_DIR, 'cache.json');
const cache = new DownloadCache(CACHE_FILE, process.argv.includes('--force'));
if (!fs.existsSync(MODEL_DIR)) {
console.log(`Creating directory: ${MODEL_DIR}`);
fs.mkdirSync(MODEL_DIR, { recursive: true });
Expand All @@ -71,6 +98,12 @@ async function downloadModels() {
for (const modelInfo of MODELS_TO_DOWNLOAD) {
const { id: modelId, task: modelTask, dtype: modelDType } = modelInfo;

const cacheKey = `${modelId}-${modelTask}-${modelDType}`;
if (cache.has(cacheKey)) {
console.log(`Model ${modelId} (${modelTask}, dtype: ${modelDType}) already cached. Skipping.`);
continue;
}

console.log(`Downloading files for ${modelId} (${modelTask}, dtype: ${modelDType})...`);

await pipeline(
Expand All @@ -82,34 +115,41 @@ async function downloadModels() {
});

console.log(`Successfully downloaded and cached ${modelId}`);
cache.put(cacheKey);
}

// Download Marqo/marqo-fashionSigLIP model
console.log(`Downloading files for Marqo/marqo-fashionSigLIP (zero-shot-image-classification, dtype: bnb4)...`);
await SiglipVisionModel.from_pretrained("Marqo/marqo-fashionSigLIP" ,{
cache_dir: env.localModelPath,
dtype: 'bnb4'
});
await AutoImageProcessor.from_pretrained("Marqo/marqo-fashionSigLIP", {
cache_dir: env.localModelPath,
});
await SiglipTextModel.from_pretrained("Marqo/marqo-fashionSigLIP", {
cache_dir: env.localModelPath,
dtype: 'bnb4'
});
await AutoTokenizer.from_pretrained("Marqo/marqo-fashionSigLIP", {
cache_dir: env.localModelPath,
});
console.log(`Successfully downloaded and cached Marqo/marqo-fashionSigLIP`);
console.log(`Checking Marqo/marqo-fashionSigLIP models...`);
for (const modelInfo of MARGO_MODELS_TO_DOWNLOAD) {
const cacheKey = `${modelInfo.class}-${modelInfo.dtype}`;
if (cache.has(cacheKey)) {
console.log(`Model ${modelInfo.class} (dtype: ${modelInfo.dtype}) already cached. Skipping.`);
continue;
}

console.log(`Downloading Marqo/marqo-fashionSigLIP (${modelInfo.class}${modelInfo.dtype ? `, dtype: ${modelInfo.dtype}` : ''})...`);
await MARGO_NAME_TO_CLASS[modelInfo.class].from_pretrained("Marqo/marqo-fashionSigLIP", {
cache_dir: env.localModelPath,
dtype: modelInfo.dtype
});

cache.put(cacheKey);
}
console.log(`Successfully checked Marqo/marqo-fashionSigLIP`);

// Download onnx-community/Kokoro-82M-v1.0-ONNX model
console.log(`Starting manual download for ${KOKORO_REPO}...`);
console.log(`Starting manual download check for ${KOKORO_REPO}...`);
const kokoroModelPath = path.join(MODEL_DIR, KOKORO_REPO);
if (!fs.existsSync(kokoroModelPath)) {
fs.mkdirSync(kokoroModelPath, { recursive: true });
}

for (const filename of KOKORO_FILES) {
const cacheKey = `${KOKORO_REPO}-${filename}`;
if (cache.has(cacheKey)) {
console.log(` ${filename} already exists, skipping.`);
continue;
}
const isOnnxFile = filename.endsWith('.onnx') || filename.endsWith('.onnx_data');
const modelUrl = getHuggingFaceUrl(KOKORO_REPO, filename);
let outputPath;
Expand All @@ -136,11 +176,13 @@ async function downloadModels() {
response.body.on('error', reject);
fileStream.on('finish', resolve);
});

cache.put(cacheKey);
} catch (err) {
console.error(` Failed to download ${filename}:`, err.message);
}
}
console.log(`Successfully downloaded all files for ${KOKORO_REPO}`);
console.log(`Successfully checked all files for ${KOKORO_REPO}`);

} catch (err) {
console.error("Model download failed:", err);
Expand All @@ -151,6 +193,6 @@ async function downloadModels() {
}

downloadModels().catch(err => {
console.error("Download process terminated.");
console.error("Download process terminated.", err);
process.exit(1);
});
});