@@ -4,6 +4,7 @@
 
 from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
 from huggingface_inference_toolkit.env_utils import api_inference_compat
+from huggingface_inference_toolkit import logging
 from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
 from huggingface_inference_toolkit.utils import (
     check_and_register_custom_pipeline_from_directory,
@@ -106,17 +107,52 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         if self.pipeline.task == "text-classification" and isinstance(inputs, str):
             inputs = [inputs]
             parameters.setdefault("top_k", os.environ.get("DEFAULT_TOP_K", 5))
-        resp = self.pipeline(inputs, **parameters)
-        # # We don't want to return {}
+        if self.pipeline.task == "token-classification":
+            parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
+
+        resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \
+            self.pipeline(inputs, **parameters)
+
+        if api_inference_compat():
+            if self.pipeline.task == "text-classification":
+                # We don't want to return {} but [{}] in any case
                 if isinstance(resp, list) and len(resp) > 0:
                     if not isinstance(resp[0], list):
                         return [resp]
                     return resp
-        if self.pipeline.task == "token-classification":
-            parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
-        return (
-            self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters)  # type: ignore
-        )
+            if self.pipeline.task == "feature-extraction":
+                # If the underlying library is Transformers, feature-extraction returns the headless
+                # encoder outputs as embeddings. The shape is a 3D or 4D array
+                # [n_inputs, batch_size = 1, n_sentence_tokens, num_hidden_dim].
+                # Discard the batch-size dim, which always seems to be 1, and return a 2D/3D array
+                # https://github.com/huggingface/transformers/blob/5c47d08b0d6835b8d8fc1c06d9a1bc71f6e78ace/src/transformers/pipelines/feature_extraction.py#L27
+                # for API inference (reason: mainly display).
+                new_resp = []
+                if isinstance(inputs, list):
+                    if isinstance(resp, list) and len(resp) == len(inputs):
+                        for it in resp:
+                            # The batch-size dim is the first level of each item; discard it
+                            if isinstance(it, list) and len(it) == 1:
+                                new_resp.append(it[0])
+                            else:
+                                logging.logger.warning("One of the output batch sizes differs from 1: %d", len(it))
+                                return resp
+                        return new_resp
+                    else:
+                        logging.logger.warning("Inputs and resp lengths differ (or resp is not a list, type %s)",
+                                               type(resp))
+                        return resp
+                elif isinstance(inputs, str):
+                    if isinstance(resp, list) and len(resp) == 1:
+                        return resp[0]
+                    else:
+                        logging.logger.warning("The output batch size differs from 1: %d", len(resp))
+                        return resp
+                else:
+                    logging.logger.warning("Output unexpected type %s", type(resp))
+                    return resp
+
+        return resp
 
 
 class VertexAIHandler(HuggingFaceHandler):
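
For readers skimming the hunk above, here is a minimal standalone sketch of the new call dispatch: dict payloads are splatted as keyword arguments (e.g. a question-answering payload with question/context keys), while strings and lists are passed positionally. The stub pipeline is hypothetical and only stands in for a Transformers pipeline's call convention.

# Hypothetical stub standing in for a Transformers pipeline call.
def stub_pipeline(inputs=None, **kwargs):
    return {"positional": inputs, "keyword": kwargs}

def dispatch(pipeline, inputs, **parameters):
    # Mirrors the handler: dict inputs are splatted as keyword arguments,
    # anything else (str, list) is passed as the positional argument.
    return pipeline(**inputs, **parameters) if isinstance(inputs, dict) else pipeline(inputs, **parameters)

print(dispatch(stub_pipeline, {"question": "Who wrote it?", "context": "Ada wrote it."}))
print(dispatch(stub_pipeline, "Hello world"))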
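
The text-classification branch only normalizes shape: a bare string is listified before the call, and a flat list of label dicts is re-wrapped so API-inference clients always get one list per input. A hedged illustration with made-up scores:

# Made-up pipeline output for a single input with top_k=2; the handler wraps
# the flat list so one-input and n-input responses share the same nesting.
resp = [{"label": "POSITIVE", "score": 0.98}, {"label": "NEGATIVE", "score": 0.02}]
if isinstance(resp, list) and len(resp) > 0 and not isinstance(resp[0], list):
    resp = [resp]
assert resp == [[{"label": "POSITIVE", "score": 0.98}, {"label": "NEGATIVE", "score": 0.02}]]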
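
And a self-contained sketch of the feature-extraction batch-dim squeeze, under the shape assumption documented in the hunk ([n_inputs, batch_size = 1, n_sentence_tokens, num_hidden_dim]); squeeze_batch_dim is a hypothetical name, not a toolkit helper.

from typing import Any

def squeeze_batch_dim(inputs: Any, resp: Any) -> Any:
    # Drop the singleton batch dimension per input; return the raw response
    # unchanged whenever the shape is not the expected one.
    if isinstance(inputs, list):
        if isinstance(resp, list) and len(resp) == len(inputs) and all(
            isinstance(it, list) and len(it) == 1 for it in resp
        ):
            return [it[0] for it in resp]
        return resp
    if isinstance(inputs, str) and isinstance(resp, list) and len(resp) == 1:
        return resp[0]
    return resp

# Two inputs, each a fake [1, 2 tokens, 3 dims] embedding -> two [2, 3] lists.
resp = [
    [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]],
    [[[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]],
]
print(squeeze_batch_dim(["first input", "second input"], resp))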