diff --git a/backend/python/rerankers/backend.py b/backend/python/rerankers/backend.py index c9a80eab4be8..aadb5b9afcae 100755 --- a/backend/python/rerankers/backend.py +++ b/backend/python/rerankers/backend.py @@ -61,7 +61,7 @@ def LoadModel(self, request, context): if request.PipelineType != "": # Reuse the PipelineType field for language kwargs['lang'] = request.PipelineType self.model_name = model_name - self.model = Reranker(model_name, **kwargs) + self.model = Reranker(model_name, **kwargs) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") @@ -80,7 +80,7 @@ def Rerank(self, request, context): index=res.doc_id, text=res.text, relevance_score=res.score - ) for res in ranked_results.results + ) for res in ranked_results.top_k(request.top_n) ] # Calculate the usage and total tokens diff --git a/backend/python/rerankers/test.py b/backend/python/rerankers/test.py index d3e4e075b8b4..3f2ddf0b7700 100755 --- a/backend/python/rerankers/test.py +++ b/backend/python/rerankers/test.py @@ -86,5 +86,33 @@ def test_rerank(self): except Exception as err: print(err) self.fail("Reranker service failed") + finally: + self.tearDown() + + def test_rerank_crop(self): + """ + This method tests if the rerank results are cropped to the requested top_n documents + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + request = backend_pb2.RerankRequest( + query="I love you", + documents=["I hate you", "I really like you", "I hate ignoring top_n"], + top_n=2 + ) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder")) + self.assertTrue(response.success) + + rerank_response = stub.Rerank(request) + print(rerank_response.results[0]) + self.assertIsNotNone(rerank_response.results) + self.assertEqual(len(rerank_response.results), 2) + self.assertEqual(rerank_response.results[0].text, "I really like you") + 
self.assertEqual(rerank_response.results[1].text, "I hate you") + except Exception as err: + print(err) + self.fail("Reranker service failed") finally: self.tearDown() \ No newline at end of file