|
16 | 16 | import os |
17 | 17 | import sys |
18 | 18 |
|
| 19 | +import numpy as np |
19 | 20 | import paddle |
20 | 21 | from paddle import inference |
21 | 22 | from scipy import spatial |
22 | 23 |
|
23 | 24 | from paddlenlp.data import Pad, Tuple |
24 | 25 | from paddlenlp.transformers import AutoTokenizer |
| 26 | +from paddlenlp.utils.env import ( |
| 27 | + PADDLE_INFERENCE_MODEL_SUFFIX, |
| 28 | + PADDLE_INFERENCE_WEIGHTS_SUFFIX, |
| 29 | +) |
25 | 30 | from paddlenlp.utils.log import logger |
26 | 31 |
|
27 | 32 | sys.path.append(".") |
@@ -90,8 +95,8 @@ def __init__( |
90 | 95 | self.max_seq_length = max_seq_length |
91 | 96 | self.batch_size = batch_size |
92 | 97 |
|
93 | | - model_file = model_dir + "/inference.get_pooled_embedding.pdmodel" |
94 | | - params_file = model_dir + "/inference.get_pooled_embedding.pdiparams" |
| 98 | + model_file = model_dir + f"/inference{PADDLE_INFERENCE_MODEL_SUFFIX}" |
| 99 | + params_file = model_dir + f"/inference{PADDLE_INFERENCE_WEIGHTS_SUFFIX}" |
95 | 100 | if not os.path.exists(model_file): |
96 | 101 | raise ValueError("not find model file path {}".format(model_file)) |
97 | 102 | if not os.path.exists(params_file): |
@@ -238,6 +243,9 @@ def predict(self, data, tokenizer): |
238 | 243 |
|
239 | 244 | if args.benchmark: |
240 | 245 | self.autolog.times.end(stamp=True) |
| 246 | + |
| 247 | + query_logits = np.atleast_2d(query_logits) |
| 248 | + title_logits = np.atleast_2d(title_logits) |
241 | 249 | result = [float(1 - spatial.distance.cosine(arr1, arr2)) for arr1, arr2 in zip(query_logits, title_logits)] |
242 | 250 | return result |
243 | 251 |
|
|
0 commit comments