Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js";
import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js";
import type { InferenceTask, InferenceProviderMappingEntry, Options, RequestArgs } from "../types.js";
import type { InferenceTask, InferenceProviderMappingEntry, Options, OutputType, RequestArgs } from "../types.js";
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
import type { getProviderHelper } from "./getProviderHelper.js";
import { isUrl } from "./isUrl.js";
Expand Down Expand Up @@ -112,6 +112,7 @@ export function makeRequestOptionsFromResolvedModel(
mapping: InferenceProviderMappingEntry | undefined,
options?: Options & {
task?: InferenceTask;
outputType?: OutputType;
}
): { url: string; info: RequestInit } {
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
Expand All @@ -120,7 +121,7 @@ export function makeRequestOptionsFromResolvedModel(

const provider = providerHelper.provider;

const { includeCredentials, task, signal, billTo } = options ?? {};
const { includeCredentials, task, signal, billTo, outputType } = options ?? {};
const authMethod = (() => {
if (providerHelper.clientSideRoutingOnly) {
// Closed-source providers require an accessToken (cannot be routed).
Expand Down Expand Up @@ -172,6 +173,7 @@ export function makeRequestOptionsFromResolvedModel(
model: resolvedModel,
task,
mapping,
outputType,
});
/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
Expand Down
11 changes: 8 additions & 3 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";

import type { AutomaticSpeechRecognitionOutput, ImageSegmentationOutput } from "@huggingface/tasks";
import { isUrl } from "../lib/isUrl.js";
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
import type { BodyParams, HeaderParams, InferenceTask, ModelId, OutputType, RequestArgs, UrlParams } from "../types.js";
import { delay } from "../utils/delay.js";
import { omit } from "../utils/omit.js";
import type { ImageSegmentationTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
Expand Down Expand Up @@ -199,7 +199,7 @@ export class FalAITextToImageTask extends FalAiQueueTask implements TextToImageT
response: FalAiQueueOutput,
url?: string,
headers?: Record<string, string>,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
const result = (await this.getResponseFromQueueApi(response, url, headers)) as FalAITextToImageOutput;
if (
Expand All @@ -218,7 +218,12 @@ export class FalAITextToImageTask extends FalAiQueueTask implements TextToImageT
return result.images[0].url;
}
const urlResponse = await fetch(result.images[0].url);
return await urlResponse.blob();
const blob = await urlResponse.blob();
if (outputType === "dataUrl") {
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
return blob;
}

throw new InferenceClientProviderOutputError(
Expand Down
33 changes: 25 additions & 8 deletions packages/inference/src/providers/hf-inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ import type {
ZeroShotImageClassificationOutput,
} from "@huggingface/tasks";
import { HF_ROUTER_URL } from "../config.js";
import { InferenceClientProviderOutputError } from "../errors.js";
import { InferenceClientInputError, InferenceClientProviderOutputError } from "../errors.js";
import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification.js";
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
import type { BodyParams, OutputType, RequestArgs, UrlParams } from "../types.js";
import { toArray } from "../utils/toArray.js";
import type {
AudioClassificationTaskHelper,
Expand Down Expand Up @@ -123,11 +123,20 @@ export class HFInferenceTask extends TaskProviderHelper {
}

export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper {
override preparePayload(params: BodyParams): Record<string, unknown> {
if (params.outputType === "url") {
throw new InferenceClientInputError(
"hf-inference provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
);
}
return params.args;
}

override async getResponse(
response: Base64ImageGeneration | OutputUrlImageGeneration,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
if (!response) {
throw new InferenceClientProviderOutputError(
Expand All @@ -140,25 +149,33 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
}
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
if (outputType === "dataUrl") {
return `data:image/jpeg;base64,${base64Data}`;
}
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
return await base64Response.blob();
}
if ("output" in response && Array.isArray(response.output)) {
if (outputType === "url") {
return response.output[0];
if (outputType === "dataUrl") {
// Fetch the URL and convert to dataUrl
const urlResponse = await fetch(response.output[0]);
const blob = await urlResponse.blob();
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
const urlResponse = await fetch(response.output[0]);
const blob = await urlResponse.blob();
return blob;
}
}
if (response instanceof Blob) {
if (outputType === "url" || outputType === "json") {
if (outputType === "dataUrl") {
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
if (outputType === "json") {
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return outputType === "url" ? `data:image/jpeg;base64,${b64}` : { output: `data:image/jpeg;base64,${b64}` };
return { output: `data:image/jpeg;base64,${b64}` };
}
return response;
}
Expand Down
13 changes: 9 additions & 4 deletions packages/inference/src/providers/hyperbolic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
* Thanks!
*/
import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
import type { BodyParams, UrlParams } from "../types.js";
import type { BodyParams, OutputType, UrlParams } from "../types.js";
import { omit } from "../utils/omit.js";
import {
BaseConversationalTask,
BaseTextGenerationTask,
TaskProviderHelper,
type TextToImageTaskHelper,
} from "./providerHelper.js";
import { InferenceClientProviderOutputError } from "../errors.js";
import { InferenceClientInputError, InferenceClientProviderOutputError } from "../errors.js";
const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";

export interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
Expand Down Expand Up @@ -93,6 +93,11 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
}

preparePayload(params: BodyParams): Record<string, unknown> {
if (params.outputType === "url") {
throw new InferenceClientInputError(
"hyperbolic provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
);
}
return {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
Expand All @@ -105,7 +110,7 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
response: HyperbolicTextToImageOutput,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
if (
typeof response === "object" &&
Expand All @@ -117,7 +122,7 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
if (outputType === "json") {
return { ...response };
}
if (outputType === "url") {
if (outputType === "dataUrl") {
return `data:image/jpeg;base64,${response.images[0].image}`;
}
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
Expand Down
32 changes: 19 additions & 13 deletions packages/inference/src/providers/nebius.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* Thanks!
*/
import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks";
import type { BodyParams } from "../types.js";
import type { BodyParams, OutputType } from "../types.js";
import { omit } from "../utils/omit.js";
import {
BaseConversationalTask,
Expand All @@ -29,9 +29,10 @@ import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js"

const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";

interface NebiusBase64ImageGeneration {
interface NebiusImageGeneration {
data: Array<{
b64_json: string;
b64_json?: string;
url?: string;
}>;
}

Expand Down Expand Up @@ -102,7 +103,7 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
return {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
response_format: "b64_json",
response_format: params.outputType === "url" ? "url" : "b64_json",
prompt: params.args.inputs,
model: params.model,
};
Expand All @@ -113,27 +114,32 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
}

async getResponse(
response: NebiusBase64ImageGeneration,
response: NebiusImageGeneration,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
if (
typeof response === "object" &&
"data" in response &&
Array.isArray(response.data) &&
response.data.length > 0 &&
"b64_json" in response.data[0] &&
typeof response.data[0].b64_json === "string"
response.data.length > 0
) {
if (outputType === "json") {
return { ...response };
}
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;

if ("url" in response.data[0] && typeof response.data[0].url === "string") {
return response.data[0].url;
}

if ("b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
const base64Data = response.data[0].b64_json;
if (outputType === "dataUrl") {
return `data:image/jpeg;base64,${base64Data}`;
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
}

throw new InferenceClientProviderOutputError("Received malformed response from Nebius text-to-image API");
Expand Down
13 changes: 9 additions & 4 deletions packages/inference/src/providers/nscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
* Thanks!
*/
import type { TextToImageInput } from "@huggingface/tasks";
import type { BodyParams } from "../types.js";
import type { BodyParams, OutputType } from "../types.js";
import { omit } from "../utils/omit.js";
import { BaseConversationalTask, TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper.js";
import { InferenceClientProviderOutputError } from "../errors.js";
import { InferenceClientInputError, InferenceClientProviderOutputError } from "../errors.js";

const NSCALE_API_BASE_URL = "https://inference.api.nscale.com";

Expand All @@ -40,6 +40,11 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
}

preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
if (params.outputType === "url") {
throw new InferenceClientInputError(
"nscale provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
);
}
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
Expand All @@ -57,7 +62,7 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
response: NscaleCloudBase64ImageGeneration,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
if (
typeof response === "object" &&
Expand All @@ -71,7 +76,7 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
return { ...response };
}
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
if (outputType === "dataUrl") {
return `data:image/jpeg;base64,${base64Data}`;
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
Expand Down
14 changes: 11 additions & 3 deletions packages/inference/src/providers/providerHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,15 @@ import type {
import { HF_ROUTER_URL } from "../config.js";
import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
import type {
BaseArgs,
BodyParams,
HeaderParams,
InferenceProvider,
OutputType,
RequestArgs,
UrlParams,
} from "../types.js";
import { toArray } from "../utils/toArray.js";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
Expand All @@ -78,7 +86,7 @@ export abstract class TaskProviderHelper {
response: unknown,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob"
outputType?: OutputType
): Promise<unknown>;

/**
Expand Down Expand Up @@ -141,7 +149,7 @@ export interface TextToImageTaskHelper {
response: unknown,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>>;
preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
}
Expand Down
18 changes: 14 additions & 4 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import { InferenceClientProviderOutputError } from "../errors.js";
import { isUrl } from "../lib/isUrl.js";
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js";
import type { BodyParams, HeaderParams, OutputType, RequestArgs, UrlParams } from "../types.js";
import { omit } from "../utils/omit.js";
import {
TaskProviderHelper,
Expand Down Expand Up @@ -91,7 +91,7 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
res: ReplicateOutput | Blob,
url?: string,
headers?: Record<string, string>,
outputType?: "url" | "blob" | "json"
outputType?: OutputType
): Promise<string | Blob | Record<string, unknown>> {
void url;
void headers;
Expand All @@ -105,7 +105,12 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
return res.output;
}
const urlResponse = await fetch(res.output);
return await urlResponse.blob();
const blob = await urlResponse.blob();
if (outputType === "dataUrl") {
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
return blob;
}

// Handle array output
Expand All @@ -123,7 +128,12 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
return res.output[0];
}
const urlResponse = await fetch(res.output[0]);
return await urlResponse.blob();
const blob = await urlResponse.blob();
if (outputType === "dataUrl") {
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
return blob;
}

throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API");
Expand Down
Loading