diff --git a/README.md b/README.md index 4118689..2112e6f 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,29 @@ Simple usecase: ```python >>> import truecase >>> truecase.get_true_case('hey, what is the weather in new york?') -'Hey, what is the weather in New York?'' +'Hey, what is the weather in New York?' ``` +You can also pass an `out_of_vocabulary_token_option`, which will be used if a word is not found in the model's vocabulary: +```python +>>> import truecase +>>> truecase.get_true_case('my favorite music genre is hip-hop.', "title") +'My favorite music genre is Hip-Hop.' +``` +`out_of_vocabulary_token_option`: +- "title" < DEFAULT +- "capitalize" +- "lower" +- Or, pass if your own lambda function (takes the token with original casing as a single parameter) + +*If an invalid option is passed, title is used* + +Lambda function example: +```python +>>> import truecase +>>> truecase.get_true_case('i work in the nsa.', lambda token: token.upper()) +'I work in the NSA.' +``` ## Training your own model TODO. For now refer to Trainer.py diff --git a/tests/test_truecase.py b/tests/test_truecase.py index d6a5edf..016c511 100644 --- a/tests/test_truecase.py +++ b/tests/test_truecase.py @@ -47,3 +47,8 @@ def test_get_true_case(self): expected = "Testing $bug" result = self.tc.get_true_case(sentence) assert result == expected + + sentence = "i work in the nsa." + expected = "I work in the NSA." + result = self.tc.get_true_case(sentence, lambda token: token.upper()) + assert result == expected diff --git a/truecase/TrueCaser.py b/truecase/TrueCaser.py index 29253fb..fed333d 100644 --- a/truecase/TrueCaser.py +++ b/truecase/TrueCaser.py @@ -2,6 +2,7 @@ import os import pickle import string +from typing import Callable import nltk from nltk.tokenize import word_tokenize @@ -91,6 +92,19 @@ def get_score(self, prev_token, possible_token, next_token): def first_token_case(self, raw): return raw.capitalize() + def out_of_vocabulary_handler(self, token_og_case, out_of_vocabulary_token_option="title"): + if isinstance(out_of_vocabulary_token_option, Callable): + return out_of_vocabulary_token_option(token_og_case) + elif out_of_vocabulary_token_option == "title": + return token_og_case.title() + elif out_of_vocabulary_token_option == "capitalize": + return token_og_case.capitalize() + elif out_of_vocabulary_token_option == "lower": + return token_og_case.lower() + else: + # If value passed is invalid, use .title() + return token_og_case.title() + def get_true_case(self, sentence, out_of_vocabulary_token_option="title"): """ Wrapper function for handling untokenized input. @@ -121,7 +135,7 @@ def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="titl """ tokens_true_case = [] for token_idx, token in enumerate(tokens): - + token_og_case = token if token in string.punctuation or token.isdigit(): tokens_true_case.append(token) else: @@ -154,14 +168,7 @@ def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="titl tokens_true_case[0]) else: # Token out of vocabulary - if out_of_vocabulary_token_option == "title": - tokens_true_case.append(token.title()) - elif out_of_vocabulary_token_option == "capitalize": - tokens_true_case.append(token.capitalize()) - elif out_of_vocabulary_token_option == "lower": - tokens_true_case.append(token.lower()) - else: - tokens_true_case.append(token) + tokens_true_case.append(self.out_of_vocabulary_handler(token_og_case, out_of_vocabulary_token_option)) return tokens_true_case