From d8da13b366a5dd3fb927507407a43de7b35b79de Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Mon, 5 May 2025 16:02:48 -0700 Subject: [PATCH] defer to HF API for task types --- src/index.ts | 19 ++++++++++++++++--- src/models.ts | 19 +------------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/index.ts b/src/index.ts index ed81225..0157e02 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ import HFInferenceProviderClient from './hf.js'; const hf = new HFInferenceProviderClient({provider: 'replicate'}); +// Hit the Replicate API to get the warm/cold status for the given model const getModelStatus = async (model: InferenceModel) => { try { const [modelName, _modelVersion] = model.providerModel.split(":"); @@ -19,13 +20,25 @@ const getModelStatus = async (model: InferenceModel) => { } }; -// Get Replicate model warm/cold statuses in parallel -const statuses = await Promise.all(inferenceModels.map(getModelStatus)); +// Hit the Hugging Face API to get the Task™ type for the given model, +// e.g. "text-to-image", "image-to-image", "text-to-video", etc. +const getModelTask = async (model: InferenceModel) => { + const response = await fetch(`https://huggingface.co/api/models/${model.hfModel}`); + const data = await response.json() as { pipeline_tag: string }; + return data.pipeline_tag; +}; + +// Get Replicate model warm/cold statuses and Hugging Face tasks in parallel +const [statuses, tasks] = await Promise.all([ + Promise.all(inferenceModels.map(getModelStatus)), + Promise.all(inferenceModels.map(getModelTask)) +]); -// Set status (unless it's already manually defined on the model object) +// Set status and task (unless they're already manually defined on the model object) const replicateModels = inferenceModels.map((model, index) => ({ ...model, status: model.status ?? statuses[index], + task: model.task ?? 
tasks[index], })); console.log("\n\nReplicate model mappings:"); diff --git a/src/models.ts b/src/models.ts index 030460f..b0ee9d6 100644 --- a/src/models.ts +++ b/src/models.ts @@ -1,7 +1,7 @@ export interface InferenceModel { hfModel: string; providerModel: string; - task: string; + task?: string; // You can set this to force the value, e.g. to keep a model as 'staging' even if it's // warm/live on Replicate. If not set, the status will be inferred from the provider model @@ -12,68 +12,55 @@ export const inferenceModels: InferenceModel[] = [ { hfModel: "deepseek-ai/DeepSeek-R1", providerModel: "deepseek-ai/deepseek-r1", - task: "conversational", status: "staging", }, { hfModel: "black-forest-labs/FLUX.1-dev", providerModel: "black-forest-labs/flux-dev", - task: "text-to-image", }, { hfModel: "black-forest-labs/FLUX.1-schnell", providerModel: "black-forest-labs/flux-schnell", - task: "text-to-image", }, { hfModel: "ByteDance/Hyper-SD", providerModel: "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7", - task: "text-to-image", }, { hfModel: "ByteDance/SDXL-Lightning", providerModel: "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637", - task: "text-to-image", }, { hfModel: "playgroundai/playground-v2.5-1024px-aesthetic", providerModel: "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24", - task: "text-to-image", }, { hfModel: "stabilityai/stable-diffusion-3.5-large-turbo", providerModel: "stability-ai/stable-diffusion-3.5-large-turbo", - task: "text-to-image", }, { hfModel: "stabilityai/stable-diffusion-3.5-large", providerModel: "stability-ai/stable-diffusion-3.5-large", - task: "text-to-image", }, { hfModel: "stabilityai/stable-diffusion-3.5-medium", providerModel: "stability-ai/stable-diffusion-3.5-medium", - task: "text-to-image", }, { hfModel: "stabilityai/stable-diffusion-xl-base-1.0", providerModel: 
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", - task: "text-to-image", }, { hfModel: "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", providerModel: "nvidia/sana-sprint-1.6b", - task: "text-to-image", }, { hfModel: "OuteAI/OuteTTS-0.3-500M", providerModel: "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26", - task: "text-to-speech", }, { hfModel: "genmo/mochi-1-preview", providerModel: "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460", - task: "text-to-video", }, { hfModel: "Wan-AI/Wan2.1-T2V-14B", @@ -93,21 +80,17 @@ export const inferenceModels: InferenceModel[] = [ { hfModel: "Wan-AI/Wan2.1-T2V-1.3B", providerModel: "wan-video/wan-2.1-1.3b", - task: "text-to-video", }, { hfModel: "Lightricks/LTX-Video", providerModel: "lightricks/ltx-video:8c47da666861d081eeb4d1261853087de23923a268a69b63febdf5dc1dee08e4", - task: "text-to-video", }, { hfModel: "zeke/rider-waite-tarot-flux", providerModel: "tarot-cards/rider-waite:6d77a07ef88e8a09389385cb14d98b12629a4b23b0537b01dfeb833c32827546", - task: "text-to-image", }, { hfModel: "stepfun-ai/Step1X-Edit", providerModel: "zsxkib/step1x-edit", - task: "image-to-image", }, ];