Skip to content

Commit 7e18606

Browse files
committed
add dataUrl and raise early
1 parent ba01d73 commit 7e18606

File tree

13 files changed

+102
-49
lines changed

13 files changed

+102
-49
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js";
22
import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js";
3-
import type { InferenceTask, InferenceProviderMappingEntry, Options, RequestArgs } from "../types.js";
3+
import type { InferenceTask, InferenceProviderMappingEntry, Options, OutputType, RequestArgs } from "../types.js";
44
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
55
import type { getProviderHelper } from "./getProviderHelper.js";
66
import { isUrl } from "./isUrl.js";
@@ -112,7 +112,7 @@ export function makeRequestOptionsFromResolvedModel(
112112
mapping: InferenceProviderMappingEntry | undefined,
113113
options?: Options & {
114114
task?: InferenceTask;
115-
outputType?: "url" | "blob" | "json";
115+
outputType?: OutputType;
116116
}
117117
): { url: string; info: RequestInit } {
118118
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;

packages/inference/src/providers/fal-ai.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
1818

1919
import type { AutomaticSpeechRecognitionOutput, ImageSegmentationOutput } from "@huggingface/tasks";
2020
import { isUrl } from "../lib/isUrl.js";
21-
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
21+
import type { BodyParams, HeaderParams, InferenceTask, ModelId, OutputType, RequestArgs, UrlParams } from "../types.js";
2222
import { delay } from "../utils/delay.js";
2323
import { omit } from "../utils/omit.js";
2424
import type { ImageSegmentationTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
@@ -199,7 +199,7 @@ export class FalAITextToImageTask extends FalAiQueueTask implements TextToImageT
199199
response: FalAiQueueOutput,
200200
url?: string,
201201
headers?: Record<string, string>,
202-
outputType?: "url" | "blob" | "json"
202+
outputType?: OutputType
203203
): Promise<string | Blob | Record<string, unknown>> {
204204
const result = (await this.getResponseFromQueueApi(response, url, headers)) as FalAITextToImageOutput;
205205
if (
@@ -218,7 +218,12 @@ export class FalAITextToImageTask extends FalAiQueueTask implements TextToImageT
218218
return result.images[0].url;
219219
}
220220
const urlResponse = await fetch(result.images[0].url);
221-
return await urlResponse.blob();
221+
const blob = await urlResponse.blob();
222+
if (outputType === "dataUrl") {
223+
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
224+
return `data:image/jpeg;base64,${b64}`;
225+
}
226+
return blob;
222227
}
223228

224229
throw new InferenceClientProviderOutputError(

packages/inference/src/providers/hf-inference.ts

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import type {
3838
import { HF_ROUTER_URL } from "../config.js";
3939
import { InferenceClientInputError, InferenceClientProviderOutputError } from "../errors.js";
4040
import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification.js";
41-
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
41+
import type { BodyParams, OutputType, RequestArgs, UrlParams } from "../types.js";
4242
import { toArray } from "../utils/toArray.js";
4343
import type {
4444
AudioClassificationTaskHelper,
@@ -123,11 +123,20 @@ export class HFInferenceTask extends TaskProviderHelper {
123123
}
124124

125125
export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper {
126+
override preparePayload(params: BodyParams): Record<string, unknown> {
127+
if (params.outputType === "url") {
128+
throw new InferenceClientInputError(
129+
"hf-inference provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
130+
);
131+
}
132+
return params.args;
133+
}
134+
126135
override async getResponse(
127136
response: Base64ImageGeneration | OutputUrlImageGeneration,
128137
url?: string,
129138
headers?: HeadersInit,
130-
outputType?: "url" | "blob" | "json"
139+
outputType?: OutputType
131140
): Promise<string | Blob | Record<string, unknown>> {
132141
if (!response) {
133142
throw new InferenceClientProviderOutputError(
@@ -140,28 +149,29 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
140149
}
141150
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
142151
const base64Data = response.data[0].b64_json;
143-
if (outputType === "url") {
144-
throw new InferenceClientInputError(
145-
"hf-inference provider does not support URL output for this model. Use outputType 'blob' or 'json' instead."
146-
);
152+
if (outputType === "dataUrl") {
153+
return `data:image/jpeg;base64,${base64Data}`;
147154
}
148155
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
149156
return await base64Response.blob();
150157
}
151158
if ("output" in response && Array.isArray(response.output)) {
152-
if (outputType === "url") {
153-
return response.output[0];
159+
if (outputType === "dataUrl") {
160+
// Fetch the URL and convert to dataUrl
161+
const urlResponse = await fetch(response.output[0]);
162+
const blob = await urlResponse.blob();
163+
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
164+
return `data:image/jpeg;base64,${b64}`;
154165
}
155166
const urlResponse = await fetch(response.output[0]);
156167
const blob = await urlResponse.blob();
157168
return blob;
158169
}
159170
}
160171
if (response instanceof Blob) {
161-
if (outputType === "url") {
162-
throw new InferenceClientInputError(
163-
"hf-inference provider does not support URL output for this model. Use outputType 'blob' or 'json' instead."
164-
);
172+
if (outputType === "dataUrl") {
173+
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
174+
return `data:image/jpeg;base64,${b64}`;
165175
}
166176
if (outputType === "json") {
167177
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));

packages/inference/src/providers/hyperbolic.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* Thanks!
1616
*/
1717
import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
18-
import type { BodyParams, UrlParams } from "../types.js";
18+
import type { BodyParams, OutputType, UrlParams } from "../types.js";
1919
import { omit } from "../utils/omit.js";
2020
import {
2121
BaseConversationalTask,
@@ -93,6 +93,11 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
9393
}
9494

9595
preparePayload(params: BodyParams): Record<string, unknown> {
96+
if (params.outputType === "url") {
97+
throw new InferenceClientInputError(
98+
"hyperbolic provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
99+
);
100+
}
96101
return {
97102
...omit(params.args, ["inputs", "parameters"]),
98103
...(params.args.parameters as Record<string, unknown>),
@@ -105,7 +110,7 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
105110
response: HyperbolicTextToImageOutput,
106111
url?: string,
107112
headers?: HeadersInit,
108-
outputType?: "url" | "blob" | "json"
113+
outputType?: OutputType
109114
): Promise<string | Blob | Record<string, unknown>> {
110115
if (
111116
typeof response === "object" &&
@@ -117,10 +122,8 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
117122
if (outputType === "json") {
118123
return { ...response };
119124
}
120-
if (outputType === "url") {
121-
throw new InferenceClientInputError(
122-
"hyperbolic provider does not support URL output. Use outputType 'blob' or 'json' instead."
123-
);
125+
if (outputType === "dataUrl") {
126+
return `data:image/jpeg;base64,${response.images[0].image}`;
124127
}
125128
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
126129
}

packages/inference/src/providers/nebius.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* Thanks!
1616
*/
1717
import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks";
18-
import type { BodyParams } from "../types.js";
18+
import type { BodyParams, OutputType } from "../types.js";
1919
import { omit } from "../utils/omit.js";
2020
import {
2121
BaseConversationalTask,
@@ -117,7 +117,7 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
117117
response: NebiusImageGeneration,
118118
url?: string,
119119
headers?: HeadersInit,
120-
outputType?: "url" | "blob" | "json"
120+
outputType?: OutputType
121121
): Promise<string | Blob | Record<string, unknown>> {
122122
if (
123123
typeof response === "object" &&
@@ -135,6 +135,9 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
135135

136136
if ("b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
137137
const base64Data = response.data[0].b64_json;
138+
if (outputType === "dataUrl") {
139+
return `data:image/jpeg;base64,${base64Data}`;
140+
}
138141
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
139142
}
140143
}

packages/inference/src/providers/nscale.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* Thanks!
1616
*/
1717
import type { TextToImageInput } from "@huggingface/tasks";
18-
import type { BodyParams } from "../types.js";
18+
import type { BodyParams, OutputType } from "../types.js";
1919
import { omit } from "../utils/omit.js";
2020
import { BaseConversationalTask, TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper.js";
2121
import { InferenceClientInputError, InferenceClientProviderOutputError } from "../errors.js";
@@ -40,6 +40,11 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
4040
}
4141

4242
preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
43+
if (params.outputType === "url") {
44+
throw new InferenceClientInputError(
45+
"nscale provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
46+
);
47+
}
4348
return {
4449
...omit(params.args, ["inputs", "parameters"]),
4550
...params.args.parameters,
@@ -57,7 +62,7 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
5762
response: NscaleCloudBase64ImageGeneration,
5863
url?: string,
5964
headers?: HeadersInit,
60-
outputType?: "url" | "blob" | "json"
65+
outputType?: OutputType
6166
): Promise<string | Blob | Record<string, unknown>> {
6267
if (
6368
typeof response === "object" &&
@@ -71,10 +76,8 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
7176
return { ...response };
7277
}
7378
const base64Data = response.data[0].b64_json;
74-
if (outputType === "url") {
75-
throw new InferenceClientInputError(
76-
"nscale provider does not support URL output. Use outputType 'blob' or 'json' instead."
77-
);
79+
if (outputType === "dataUrl") {
80+
return `data:image/jpeg;base64,${base64Data}`;
7881
}
7982
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
8083
}

packages/inference/src/providers/providerHelper.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ import type {
5151
import { HF_ROUTER_URL } from "../config.js";
5252
import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
5353
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
54-
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
54+
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, OutputType, RequestArgs, UrlParams } from "../types.js";
5555
import { toArray } from "../utils/toArray.js";
5656
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
5757
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
@@ -78,7 +78,7 @@ export abstract class TaskProviderHelper {
7878
response: unknown,
7979
url?: string,
8080
headers?: HeadersInit,
81-
outputType?: "url" | "blob"
81+
outputType?: OutputType
8282
): Promise<unknown>;
8383

8484
/**
@@ -141,7 +141,7 @@ export interface TextToImageTaskHelper {
141141
response: unknown,
142142
url?: string,
143143
headers?: HeadersInit,
144-
outputType?: "url" | "blob" | "json"
144+
outputType?: OutputType
145145
): Promise<string | Blob | Record<string, unknown>>;
146146
preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
147147
}

packages/inference/src/providers/replicate.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
import { InferenceClientProviderOutputError } from "../errors.js";
1818
import { isUrl } from "../lib/isUrl.js";
19-
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js";
19+
import type { BodyParams, HeaderParams, OutputType, RequestArgs, UrlParams } from "../types.js";
2020
import { omit } from "../utils/omit.js";
2121
import {
2222
TaskProviderHelper,
@@ -91,7 +91,7 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
9191
res: ReplicateOutput | Blob,
9292
url?: string,
9393
headers?: Record<string, string>,
94-
outputType?: "url" | "blob" | "json"
94+
outputType?: OutputType
9595
): Promise<string | Blob | Record<string, unknown>> {
9696
void url;
9797
void headers;
@@ -105,7 +105,12 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
105105
return res.output;
106106
}
107107
const urlResponse = await fetch(res.output);
108-
return await urlResponse.blob();
108+
const blob = await urlResponse.blob();
109+
if (outputType === "dataUrl") {
110+
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
111+
return `data:image/jpeg;base64,${b64}`;
112+
}
113+
return blob;
109114
}
110115

111116
// Handle array output
@@ -123,7 +128,12 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
123128
return res.output[0];
124129
}
125130
const urlResponse = await fetch(res.output[0]);
126-
return await urlResponse.blob();
131+
const blob = await urlResponse.blob();
132+
if (outputType === "dataUrl") {
133+
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
134+
return `data:image/jpeg;base64,${b64}`;
135+
}
136+
return blob;
127137
}
128138

129139
throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API");

packages/inference/src/providers/together.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* Thanks!
1616
*/
1717
import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
18-
import type { BodyParams } from "../types.js";
18+
import type { BodyParams, OutputType } from "../types.js";
1919
import { omit } from "../utils/omit.js";
2020
import {
2121
BaseConversationalTask,
@@ -119,7 +119,7 @@ export class TogetherTextToImageTask extends TaskProviderHelper implements TextT
119119
response: TogetherImageGeneration,
120120
url?: string,
121121
headers?: HeadersInit,
122-
outputType?: "url" | "blob" | "json"
122+
outputType?: OutputType
123123
): Promise<string | Blob | Record<string, unknown>> {
124124
if (
125125
typeof response === "object" &&
@@ -137,6 +137,9 @@ export class TogetherTextToImageTask extends TaskProviderHelper implements TextT
137137

138138
if ("b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
139139
const base64Data = response.data[0].b64_json;
140+
if (outputType === "dataUrl") {
141+
return `data:image/jpeg;base64,${base64Data}`;
142+
}
140143
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
141144
}
142145
}

packages/inference/src/providers/wavespeed.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
22
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
33
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
44
import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js";
5-
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
5+
import type { BodyParams, OutputType, RequestArgs, UrlParams } from "../types.js";
66
import { delay } from "../utils/delay.js";
77
import { omit } from "../utils/omit.js";
88
import { base64FromBytes } from "../utils/base64FromBytes.js";
@@ -116,7 +116,7 @@ abstract class WavespeedAITask extends TaskProviderHelper {
116116
response: WaveSpeedAISubmitTaskResponse,
117117
url?: string,
118118
headers?: Record<string, string>,
119-
outputType?: "url" | "blob" | "json"
119+
outputType?: OutputType
120120
): Promise<string | Blob | Record<string, unknown>> {
121121
if (!url || !headers) {
122122
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
@@ -179,7 +179,12 @@ abstract class WavespeedAITask extends TaskProviderHelper {
179179
}
180180
);
181181
}
182-
return await mediaResponse.blob();
182+
const blob = await mediaResponse.blob();
183+
if (outputType === "dataUrl") {
184+
const b64 = await blob.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
185+
return `data:image/jpeg;base64,${b64}`;
186+
}
187+
return blob;
183188
}
184189
case "failed": {
185190
throw new InferenceClientProviderOutputError(taskResult.error || "Task failed");

0 commit comments

Comments
 (0)