@@ -38,7 +38,7 @@ import type {
3838import { HF_ROUTER_URL } from "../config.js" ;
3939import { InferenceClientInputError , InferenceClientProviderOutputError } from "../errors.js" ;
4040import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification.js" ;
41- import type { BodyParams , RequestArgs , UrlParams } from "../types.js" ;
41+ import type { BodyParams , OutputType , RequestArgs , UrlParams } from "../types.js" ;
4242import { toArray } from "../utils/toArray.js" ;
4343import type {
4444 AudioClassificationTaskHelper ,
@@ -123,11 +123,20 @@ export class HFInferenceTask extends TaskProviderHelper {
123123}
124124
125125export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper {
126+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
127+ if ( params . outputType === "url" ) {
128+ throw new InferenceClientInputError (
129+ "hf-inference provider does not support URL output. Use outputType 'blob', 'dataUrl' or 'json' instead."
130+ ) ;
131+ }
132+ return params . args ;
133+ }
134+
126135 override async getResponse (
127136 response : Base64ImageGeneration | OutputUrlImageGeneration ,
128137 url ?: string ,
129138 headers ?: HeadersInit ,
130- outputType ?: "url" | "blob" | "json"
139+ outputType ?: OutputType
131140 ) : Promise < string | Blob | Record < string , unknown > > {
132141 if ( ! response ) {
133142 throw new InferenceClientProviderOutputError (
@@ -140,28 +149,29 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
140149 }
141150 if ( "data" in response && Array . isArray ( response . data ) && response . data [ 0 ] . b64_json ) {
142151 const base64Data = response . data [ 0 ] . b64_json ;
143- if ( outputType === "url" ) {
144- throw new InferenceClientInputError (
145- "hf-inference provider does not support URL output for this model. Use outputType 'blob' or 'json' instead."
146- ) ;
152+ if ( outputType === "dataUrl" ) {
153+ return `data:image/jpeg;base64,${ base64Data } ` ;
147154 }
148155 const base64Response = await fetch ( `data:image/jpeg;base64,${ base64Data } ` ) ;
149156 return await base64Response . blob ( ) ;
150157 }
151158 if ( "output" in response && Array . isArray ( response . output ) ) {
152- if ( outputType === "url" ) {
153- return response . output [ 0 ] ;
159+ if ( outputType === "dataUrl" ) {
160+ // Fetch the URL and convert to dataUrl
161+ const urlResponse = await fetch ( response . output [ 0 ] ) ;
162+ const blob = await urlResponse . blob ( ) ;
163+ const b64 = await blob . arrayBuffer ( ) . then ( ( buf ) => Buffer . from ( buf ) . toString ( "base64" ) ) ;
164+ return `data:image/jpeg;base64,${ b64 } ` ;
154165 }
155166 const urlResponse = await fetch ( response . output [ 0 ] ) ;
156167 const blob = await urlResponse . blob ( ) ;
157168 return blob ;
158169 }
159170 }
160171 if ( response instanceof Blob ) {
161- if ( outputType === "url" ) {
162- throw new InferenceClientInputError (
163- "hf-inference provider does not support URL output for this model. Use outputType 'blob' or 'json' instead."
164- ) ;
172+ if ( outputType === "dataUrl" ) {
173+ const b64 = await response . arrayBuffer ( ) . then ( ( buf ) => Buffer . from ( buf ) . toString ( "base64" ) ) ;
174+ return `data:image/jpeg;base64,${ b64 } ` ;
165175 }
166176 if ( outputType === "json" ) {
167177 const b64 = await response . arrayBuffer ( ) . then ( ( buf ) => Buffer . from ( buf ) . toString ( "base64" ) ) ;
0 commit comments