Skip to content

Commit 39db7c6

Browse files
committed
fix(api-inference): feature-extraction, flatten array, discard the batch size dim
Signed-off-by: Raphael Glon <[email protected]>
1 parent e94b6fb commit 39db7c6

File tree

1 file changed

+43
-7
lines changed

1 file changed

+43
-7
lines changed

src/huggingface_inference_toolkit/handler.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
66
from huggingface_inference_toolkit.env_utils import api_inference_compat
7+
from huggingface_inference_toolkit import logging
78
from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
89
from huggingface_inference_toolkit.utils import (
910
check_and_register_custom_pipeline_from_directory,
@@ -106,17 +107,52 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
106107
if self.pipeline.task == "text-classification" and isinstance(inputs, str):
107108
inputs = [inputs]
108109
parameters.setdefault("top_k", os.environ.get("DEFAULT_TOP_K", 5))
109-
resp = self.pipeline(inputs, **parameters)
110-
# # We don't want to return {}
110+
if self.pipeline.task == "token-classification":
111+
parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
112+
113+
resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \
114+
self.pipeline(inputs, **parameters)
115+
116+
if api_inference_compat():
117+
if self.pipeline.task == "text-classification":
118+
# We don't want to return {} but [{}] in any case
111119
if isinstance(resp, list) and len(resp) > 0:
112120
if not isinstance(resp[0], list):
113121
return [resp]
114122
return resp
115-
if self.pipeline.task == "token-classification":
116-
parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
117-
return (
118-
self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters) # type: ignore
119-
)
123+
if self.pipeline.task == "feature-extraction":
124+
# If the library used is Transformers then the feature-extraction is returning the headless encoder
125+
# outputs as embeddings. The shape is a 3D or 4D array
126+
# [n_inputs, batch_size = 1, n_sentence_tokens, num_hidden_dim].
127+
# Let's just discard the batch size dim that always seems to be 1 and return a 2D/3D array
128+
# https://github.com/huggingface/transformers/blob/5c47d08b0d6835b8d8fc1c06d9a1bc71f6e78ace/src/transformers/pipelines/feature_extraction.py#L27
129+
# for api inference (reason: mainly display)
130+
new_resp = []
131+
if isinstance(inputs, list):
132+
if isinstance(resp, list) and len(resp) == len(inputs):
133+
for it in resp:
134+
# Batch size dim is the first level of it, discard it
135+
if isinstance(it, list) and len(it) == 1:
136+
new_resp.append(it[0])
137+
else:
138+
logging.logger.warning("One of the output batch size differs from 1: %d", len(it))
139+
return resp
140+
return new_resp
141+
else:
142+
logging.logger.warning("Inputs and resp len differ (or resp is not a list, type %s)",
143+
type(resp))
144+
return resp
145+
elif isinstance(inputs, str):
146+
if isinstance(resp, list) and len(resp) == 1:
147+
return resp[0]
148+
else:
149+
logging.logger.warning("The output batch size differs from 1: %d", len(resp))
150+
return resp
151+
else:
152+
logging.logger.warning("Output unexpected type %s", type(resp))
153+
return resp
154+
155+
return resp
120156

121157

122158
class VertexAIHandler(HuggingFaceHandler):

0 commit comments

Comments
 (0)