Add Reranker service to lambda #365
Merged
10 commits
73110d6 Add Reranker service to lambda (bamader)
851df44 Add some rounding to test scores (bamader)
8b6e84a Oops rounded wrong (bamader)
aebd88d Add reranker into pipeline (bamader)
fc74ab2 Instantiate Reranker (bamader)
f5d5d7e Reranker prediction (bamader)
3d1797e Code string update (bamader)
b601dee Float cast (bamader)
8be3aaf Update reranker score (bamader)
e9dd82e PR Feedback (bamader)
11 changes: 7 additions & 4 deletions
packages/text-to-code/src/text_to_code/services/embedder.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -1,18 +1,21 @@ |
| | | from sentence_transformers import SentenceTransformer |
| | | from torch import Tensor |
| | | |
| | | -from text_to_code.models.registry import default_model |
| | | +from text_to_code.models.registry import TTC_RETRIEVER |
| | | |
| | | -_MODEL = SentenceTransformer(default_model) |
| | | +_RETRIEVER = SentenceTransformer(TTC_RETRIEVER) |
| | | |
| | | |
| | | class Embedder: |
| | |     """Transforms nonstandard text.""" |
| | | |
| | |     def embed(self, text: str) -> Tensor: |
| | | -        """Take a text string and embeds it as a vector using a model as defined in config.py. |
| | | +        """Encode a text string into a vector representation. |
| | | |
| | | +        The dimensionality and values of the vector form are determined |
| | | +        by the application's default Retriever Model. |
| | | |
| | |         :param text: Text string to embed. |
| | |         :returns: Tensor representation of input text. |
| | |         """ |
| | | -        return _MODEL.encode(text) |
| | | +        return _RETRIEVER.encode(text) |
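The new docstring says the vector's dimensionality is fixed by the retriever model rather than by the input. A minimal sketch of that property follows, using a hypothetical `ToyRetriever` (a hash-bucket counter, not a real sentence-transformers model) so it runs without downloading weights. `ToyRetriever`, `ToyEmbedder`, and `_TOY_RETRIEVER` are illustration-only names that do not appear in the PR.

```python
import hashlib


class ToyRetriever:
    """Hypothetical stand-in for the retriever model; not a real encoder."""

    def __init__(self, dim: int = 8) -> None:
        self.dim = dim

    def encode(self, text: str) -> list[float]:
        # Hash each whitespace-separated token into one of `dim` buckets and
        # count hits, so every input maps to a vector of the same length.
        vec = [0.0] * self.dim
        for token in text.lower().split():
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            vec[digest[0] % self.dim] += 1.0
        return vec


# Built once at import time, mirroring the module-level _RETRIEVER in the diff.
_TOY_RETRIEVER = ToyRetriever(dim=8)


class ToyEmbedder:
    """Thin wrapper that delegates to the shared module-level model."""

    def embed(self, text: str) -> list[float]:
        return _TOY_RETRIEVER.encode(text)


short_vec = ToyEmbedder().embed("acr")
long_vec = ToyEmbedder().embed("albumin/creatinine ratio (acr)")
print(len(short_vec), len(long_vec))  # prints "8 8"
```

Both inputs produce vectors of length 8 because the length comes from the model's configuration, which is the point the docstring makes about the real retriever.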
26 changes: 26 additions & 0 deletions
packages/text-to-code/src/text_to_code/services/reranker.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -0,0 +1,26 @@ |
| | | +from sentence_transformers import CrossEncoder |
| | | + |
| | | +from text_to_code.models.registry import TTC_RERANKER |
| | | + |
| | | +_RERANKER = CrossEncoder(TTC_RERANKER) |
| | | + |
| | | + |
| | | +class Reranker: |
| | | +    """Scores and sorts OpenSearch results.""" |
| | | + |
| | | +    def rerank(self, nonstandard_in: str, hits: list[str]) -> list[dict]: |
| | | +        """Re-sorts hits by cross-encoder score values. |
| | | + |
| | | +        Given a list of text strings returned from OpenSearch, score and sort |
| | | +        the search hits using the Text-to-Code system's default Reranker model. |
| | | +        The model will generate a cross-encoding score value measuring each |
| | | +        search result's information similarity to the original nonstandard input. |
| | | + |
| | | +        :param nonstandard_in: The original narrative free-text input to TTC. |
| | | +        :param hits: The list of OpenSearch results, in text string form. |
| | | +        :returns: A list of dictionaries representing the newly cross-encoder |
| | | +        scored search results, sorted in descending order of score. |
| | | +        """ |
| | | +        ranks = _RERANKER.rank(nonstandard_in, hits) |
| | | +        sorted_ranks = [{"code_string": hits[r["corpus_id"]], "score": r["score"]} for r in ranks] |
| | | +        return sorted_ranks |
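The data flow in `rerank` can be sketched without loading a model. The example below swaps the cross-encoder for a hypothetical `toy_rank` function that scores by Jaccard token overlap. It returns entries shaped the way the diff consumes them (`corpus_id` and `score`, sorted by descending score), and the mapping step is the same list comprehension as in the PR. `toy_rank` and the sample strings are assumptions made for illustration only.

```python
def toy_rank(query: str, documents: list[str]) -> list[dict]:
    """Hypothetical stand-in for the cross-encoder ranking call.

    Scores each document by Jaccard overlap of lowercase tokens with the
    query and returns {"corpus_id", "score"} entries, best score first.
    """
    query_tokens = set(query.lower().split())
    scored = []
    for idx, doc in enumerate(documents):
        doc_tokens = set(doc.lower().split())
        union = query_tokens | doc_tokens
        score = len(query_tokens & doc_tokens) / len(union) if union else 0.0
        scored.append({"corpus_id": idx, "score": score})
    return sorted(scored, key=lambda r: r["score"], reverse=True)


def rerank(nonstandard_in: str, hits: list[str]) -> list[dict]:
    ranks = toy_rank(nonstandard_in, hits)
    # corpus_id indexes back into the original hits list, so the caller gets
    # the hit text paired with its score, already in ranked order.
    return [{"code_string": hits[r["corpus_id"]], "score": r["score"]} for r in ranks]


hits = ["glucose in urine", "glucose in blood", "sodium in serum"]
results = rerank("glucose blood", hits)
for row in results:
    print(row["code_string"], round(row["score"], 3))
```

With these inputs the second hit moves to the top because it shares the most tokens with the query. An empty `hits` list yields an empty result, which matches the first test case in the PR.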
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -0,0 +1,52 @@ |
| | | +import pytest |
| | | +from text_to_code.services.reranker import Reranker |
| | | + |
| | | + |
| | | +class TestReranker: |
| | | +    @pytest.fixture(scope="class") |
| | | +    def reranker(self) -> Reranker: |
| | | +        return Reranker() |
| | | + |
| | | +    def test_reranker_empty_hits(self, reranker: Reranker) -> None: |
| | | +        ranks = reranker.rerank("Influenza virus A and B and SARS-CoV-2 (COVID-19)", []) |
| | | +        assert len(ranks) == 0 |
| | | + |
| | | +    def test_reranker_single_search_result(self, reranker: Reranker) -> None: |
| | | +        ranks = reranker.rerank( |
| | | +            "Influenza virus A and B and SARS-CoV-2 (COVID-19)", |
| | | +            ["Influenza virus A and B and SARS-CoV-2 (COVID-19)"], |
| | | +        ) |
| | | +        ranks = [ |
| | | +            {"code_string": r["code_string"], "score": round(float(r["score"]), 3)} for r in ranks |
| | | +        ] |
| | | +        assert ranks == [ |
| | | +            {"code_string": "Influenza virus A and B and SARS-CoV-2 (COVID-19)", "score": 0.973} |
| | | +        ] |
| | | + |
| | | +    def test_reranker_multiple_hits(self, reranker: Reranker) -> None: |
| | | +        nonstandard_in = "albumin/creatinine ratio (acr)" |
| | | +        search_hits = [ |
| | | +            "Albumin/Creatinine [Ratio] in Urine", |
| | | +            "Albumin/Creatinine (U) [Mass ratio]", |
| | | +            "Albumin/Creatinine [Ratio] in 24 hour Urine", |
| | | +            "Albumin/Creatinine (U) [Molar ratio]", |
| | | +        ] |
| | | +        ranks = reranker.rerank(nonstandard_in, search_hits) |
| | | +        ranks = [ |
| | | +            {"code_string": r["code_string"], "score": round(float(r["score"]), 3)} for r in ranks |
| | | +        ] |
| | | +        assert ranks == [ |
| | | +            { |
| | | +                "code_string": "Albumin/Creatinine (U) [Mass ratio]", |
| | | +                "score": 0.755, |
| | | +            }, |
| | | +            { |
| | | +                "code_string": "Albumin/Creatinine (U) [Molar ratio]", |
| | | +                "score": 0.73, |
| | | +            }, |
| | | +            { |
| | | +                "code_string": "Albumin/Creatinine [Ratio] in 24 hour Urine", |
| | | +                "score": 0.701, |
| | | +            }, |
| | | +            {"code_string": "Albumin/Creatinine [Ratio] in Urine", "score": 0.672}, |
| | | +        ] |
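Both tests repeat the same normalization before asserting: cast each score to a built-in `float` and round to three decimals, presumably so that model scores compare equal to short literals. A minimal sketch of that step as a reusable helper follows. `normalize_ranks` is a hypothetical name that is not part of the PR, and the raw scores are made-up values chosen to show the rounding.

```python
def normalize_ranks(ranks: list[dict], ndigits: int = 3) -> list[dict]:
    """Cast each score to a built-in float and round it for stable comparison."""
    return [
        {"code_string": r["code_string"], "score": round(float(r["score"]), ndigits)}
        for r in ranks
    ]


# Made-up raw scores standing in for unrounded model output.
raw = [
    {"code_string": "A", "score": 0.75512345},
    {"code_string": "B", "score": 0.7299871},
]
normalized = normalize_ranks(raw)
print(normalized)
```

The second score rounds to `0.73`, which is why an expected value in the test above has two decimals rather than three. The trailing zero is not significant for float equality.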