Skip to content

Commit 7f02cf5

Browse files
committed
Fix kwargs propagation for sentence-transformers
1 parent e8af0f6 commit 7f02cf5

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/huggingface_inference_toolkit/sentence_transformers_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def is_sentence_transformers_available():
1212

1313

1414
class SentenceSimilarityPipeline:
15-
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
16-
self.model = SentenceTransformer(model_dir, device=device)
15+
def __init__(self, model_dir: str, device: str = None, **kwargs): # needs "cuda" for GPU
16+
self.model = SentenceTransformer(model_dir, device=device, **kwargs)
1717

1818
def __call__(self, inputs=None):
1919
embeddings1 = self.model.encode(
@@ -25,17 +25,17 @@ def __call__(self, inputs=None):
2525

2626

2727
class SentenceEmbeddingPipeline:
28-
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
29-
self.model = SentenceTransformer(model_dir, device=device)
28+
def __init__(self, model_dir: str, device: str = None, **kwargs): # needs "cuda" for GPU
29+
self.model = SentenceTransformer(model_dir, device=device, **kwargs)
3030

3131
def __call__(self, inputs):
3232
embeddings = self.model.encode(inputs).tolist()
3333
return {"embeddings": embeddings}
3434

3535

3636
class RankingPipeline:
37-
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
38-
self.model = CrossEncoder(model_dir, device=device)
37+
def __init__(self, model_dir: str, device: str = None, **kwargs): # needs "cuda" for GPU
38+
self.model = CrossEncoder(model_dir, device=device, **kwargs)
3939

4040
def __call__(self, inputs):
4141
scores = self.model.predict(inputs).tolist()

0 commit comments

Comments
 (0)