@@ -12,8 +12,8 @@ def is_sentence_transformers_available():
12
12
13
13
14
14
class SentenceSimilarityPipeline :
15
- def __init__ (self , model_dir : str , device : str = None ): # needs "cuda" for GPU
16
- self .model = SentenceTransformer (model_dir , device = device )
15
+ def __init__ (self , model_dir : str , device : str = None , ** kwargs ): # needs "cuda" for GPU
16
+ self .model = SentenceTransformer (model_dir , device = device , ** kwargs )
17
17
18
18
def __call__ (self , inputs = None ):
19
19
embeddings1 = self .model .encode (
@@ -25,17 +25,17 @@ def __call__(self, inputs=None):
25
25
26
26
27
27
class SentenceEmbeddingPipeline :
28
- def __init__ (self , model_dir : str , device : str = None ): # needs "cuda" for GPU
29
- self .model = SentenceTransformer (model_dir , device = device )
28
+ def __init__ (self , model_dir : str , device : str = None , ** kwargs ): # needs "cuda" for GPU
29
+ self .model = SentenceTransformer (model_dir , device = device , ** kwargs )
30
30
31
31
def __call__ (self , inputs ):
32
32
embeddings = self .model .encode (inputs ).tolist ()
33
33
return {"embeddings" : embeddings }
34
34
35
35
36
36
class RankingPipeline :
37
- def __init__ (self , model_dir : str , device : str = None ): # needs "cuda" for GPU
38
- self .model = CrossEncoder (model_dir , device = device )
37
+ def __init__ (self , model_dir : str , device : str = None , ** kwargs ): # needs "cuda" for GPU
38
+ self .model = CrossEncoder (model_dir , device = device , ** kwargs )
39
39
40
40
def __call__ (self , inputs ):
41
41
scores = self .model .predict (inputs ).tolist ()
0 commit comments