Skip to content

Use Hugging Face HTTP API to determine task types #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import HFInferenceProviderClient from './hf.js';

const hf = new HFInferenceProviderClient({provider: 'replicate'});

// Hit the Replicate API to get the warm/cold status for the given model
const getModelStatus = async (model: InferenceModel) => {
try {
const [modelName, _modelVersion] = model.providerModel.split(":");
Expand All @@ -19,13 +20,25 @@ const getModelStatus = async (model: InferenceModel) => {
}
};

// Get Replicate model warm/cold statuses in parallel
const statuses = await Promise.all(inferenceModels.map(getModelStatus));
// Hit the Hugging Face API to get the Task™ type for the given model,
// e.g. "text-to-image", "image-to-image", "text-to-video", etc.
const getModelTask = async (model: InferenceModel) => {
const response = await fetch(`https://huggingface.co/api/models/${model.hfModel}`);
const data = await response.json() as { pipeline_tag: string };
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can potentially be undefined (but should be rare)

return data.pipeline_tag;
};

// Resolve every model's Replicate warm/cold status and Hugging Face task
// concurrently; both lookups are independent, so kick them off together.
const statusLookups = Promise.all(inferenceModels.map(getModelStatus));
const taskLookups = Promise.all(inferenceModels.map(getModelTask));
const [statuses, tasks] = await Promise.all([statusLookups, taskLookups]);

// Merge the fetched status and task into each model, preferring any value
// that was already hard-coded on the model object itself.
const replicateModels = inferenceModels.map((model, index) => {
  return {
    ...model,
    status: model.status ?? statuses[index],
    task: model.task ?? tasks[index],
  };
});

console.log("\n\nReplicate model mappings:");
Expand Down
19 changes: 1 addition & 18 deletions src/models.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export interface InferenceModel {
hfModel: string;
providerModel: string;
task: string;
task?: string;

// You can set this to force the value, e.g. to keep a model as 'staging' even if it's
// warm/live on Replicate. If not set, the status will be inferred from the provider model
Expand All @@ -12,68 +12,55 @@ export const inferenceModels: InferenceModel[] = [
{
hfModel: "deepseek-ai/DeepSeek-R1",
providerModel: "deepseek-ai/deepseek-r1",
task: "conversational",
status: "staging",
},
{
hfModel: "black-forest-labs/FLUX.1-dev",
providerModel: "black-forest-labs/flux-dev",
task: "text-to-image",
},
{
hfModel: "black-forest-labs/FLUX.1-schnell",
providerModel: "black-forest-labs/flux-schnell",
task: "text-to-image",
},
{
hfModel: "ByteDance/Hyper-SD",
providerModel: "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
task: "text-to-image",
},
{
hfModel: "ByteDance/SDXL-Lightning",
providerModel: "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
task: "text-to-image",
},
{
hfModel: "playgroundai/playground-v2.5-1024px-aesthetic",
providerModel: "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
task: "text-to-image",
},
{
hfModel: "stabilityai/stable-diffusion-3.5-large-turbo",
providerModel: "stability-ai/stable-diffusion-3.5-large-turbo",
task: "text-to-image",
},
{
hfModel: "stabilityai/stable-diffusion-3.5-large",
providerModel: "stability-ai/stable-diffusion-3.5-large",
task: "text-to-image",
},
{
hfModel: "stabilityai/stable-diffusion-3.5-medium",
providerModel: "stability-ai/stable-diffusion-3.5-medium",
task: "text-to-image",
},
{
hfModel: "stabilityai/stable-diffusion-xl-base-1.0",
providerModel: "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
task: "text-to-image",
},
{
hfModel: "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
providerModel: "nvidia/sana-sprint-1.6b",
task: "text-to-image",
},
{
hfModel: "OuteAI/OuteTTS-0.3-500M",
providerModel: "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
task: "text-to-speech",
},
{
hfModel: "genmo/mochi-1-preview",
providerModel: "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
task: "text-to-video",
},
{
hfModel: "Wan-AI/Wan2.1-T2V-14B",
Expand All @@ -93,21 +80,17 @@ export const inferenceModels: InferenceModel[] = [
{
hfModel: "Wan-AI/Wan2.1-T2V-1.3B",
providerModel: "wan-video/wan-2.1-1.3b",
task: "text-to-video",
},
{
hfModel: "Lightricks/LTX-Video",
providerModel: "lightricks/ltx-video:8c47da666861d081eeb4d1261853087de23923a268a69b63febdf5dc1dee08e4",
task: "text-to-video",
},
{
hfModel: "zeke/rider-waite-tarot-flux",
providerModel: "tarot-cards/rider-waite:6d77a07ef88e8a09389385cb14d98b12629a4b23b0537b01dfeb833c32827546",
task: "text-to-image",
},
{
hfModel: "stepfun-ai/Step1X-Edit",
providerModel: "zsxkib/step1x-edit",
task: "image-to-image",
},
];