From 825a6dbf7bbe90fbacc8df8629c35e81ba4acdea Mon Sep 17 00:00:00 2001 From: Keshav Pradeep <32313895+keshprad@users.noreply.github.com> Date: Fri, 2 Jul 2021 10:51:13 -0700 Subject: [PATCH] Use word_tokenize in combo with TreebankWordDetokenizer. Other small changes too (#21) * use capitalize for first_token_case and an out_of_vocab option * refactor nominator to numerator (more descriptive) * Use word_tokenizer instead of TweetTokenizer and TreebankWordDetokenizer to join tokens * Add test cases * Update __init__.py Co-authored-by: daltonfury42 --- .gitignore | 5 ++++- tests/test_truecase.py | 20 ++++++++++++++++++-- truecase/TrueCaser.py | 36 ++++++++++++++++++------------------ truecase/__init__.py | 2 +- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 590105d..fe99238 100644 --- a/.gitignore +++ b/.gitignore @@ -77,4 +77,7 @@ fabric.properties .pytest_cache build/ -.eggs/ \ No newline at end of file +.eggs/ + +# virtual env +.env/ \ No newline at end of file diff --git a/tests/test_truecase.py b/tests/test_truecase.py index a66408f..c9e6689 100644 --- a/tests/test_truecase.py +++ b/tests/test_truecase.py @@ -10,14 +10,30 @@ def setUp(self): def test_get_true_case(self): sentence = "I live in barcelona." expected = "I live in Barcelona." - result = self.tc.get_true_case(sentence) - assert result == expected sentence = "My name is irvine wels" expected = "My name is Irvine Wels" + result = self.tc.get_true_case(sentence) + assert result == expected + + sentence = "i paid $50 FOR My shoes." + expected = "I paid $50 for my shoes." + result = self.tc.get_true_case(sentence) + assert result == expected + + sentence = "Ron'S show Is a big Hit." + expected = "Ron's show is a big hit." + result = self.tc.get_true_case(sentence) + assert result == expected + sentence = "What Is Your name?" + expected = "What is your name?" result = self.tc.get_true_case(sentence) + assert result == expected + sentence = "at The moment, I AM getting ready for work!" + expected = "At the moment, I am getting ready for work!" + result = self.tc.get_true_case(sentence) assert result == expected diff --git a/truecase/TrueCaser.py b/truecase/TrueCaser.py index c180d23..1e5c309 100644 --- a/truecase/TrueCaser.py +++ b/truecase/TrueCaser.py @@ -4,7 +4,8 @@ import string import nltk -from nltk.tokenize import TweetTokenizer +from nltk.tokenize import word_tokenize +from nltk.tokenize.treebank import TreebankWordDetokenizer class TrueCaser(object): @@ -22,24 +23,24 @@ def __init__(self, dist_file_path=None): self.forward_bi_dist = pickle_dict["forward_bi_dist"] self.trigram_dist = pickle_dict["trigram_dist"] self.word_casing_lookup = pickle_dict["word_casing_lookup"] - self.tknzr = TweetTokenizer() + self.detknzr = TreebankWordDetokenizer() def get_score(self, prev_token, possible_token, next_token): pseudo_count = 5.0 # Get Unigram Score - nominator = self.uni_dist[possible_token] + pseudo_count + numerator = self.uni_dist[possible_token] + pseudo_count denominator = 0 for alternativeToken in self.word_casing_lookup[ possible_token.lower()]: denominator += self.uni_dist[alternativeToken] + pseudo_count - unigram_score = nominator / denominator + unigram_score = numerator / denominator # Get Backward Score bigram_backward_score = 1 if prev_token is not None: - nominator = ( + numerator = ( self.backward_bi_dist[prev_token + "_" + possible_token] + pseudo_count) denominator = 0 @@ -49,13 +50,13 @@ def get_score(self, prev_token, possible_token, next_token): alternativeToken] + pseudo_count) - bigram_backward_score = nominator / denominator + bigram_backward_score = numerator / denominator # Get Forward Score bigram_forward_score = 1 if next_token is not None: next_token = next_token.lower() # Ensure it is lower case - nominator = ( + numerator = ( self.forward_bi_dist[possible_token + "_" + next_token] + pseudo_count) denominator = 0 @@ -65,13 +66,13 @@ def get_score(self, prev_token, possible_token, next_token): self.forward_bi_dist[alternativeToken + "_" + next_token] + pseudo_count) - bigram_forward_score = nominator / denominator + bigram_forward_score = numerator / denominator # Get Trigram Score trigram_score = 1 if prev_token is not None and next_token is not None: next_token = next_token.lower() # Ensure it is lower case - nominator = (self.trigram_dist[prev_token + "_" + possible_token + + numerator = (self.trigram_dist[prev_token + "_" + possible_token + "_" + next_token] + pseudo_count) denominator = 0 for alternativeToken in self.word_casing_lookup[ @@ -80,7 +81,7 @@ def get_score(self, prev_token, possible_token, next_token): self.trigram_dist[prev_token + "_" + alternativeToken + "_" + next_token] + pseudo_count) - trigram_score = nominator / denominator + trigram_score = numerator / denominator result = (math.log(unigram_score) + math.log(bigram_backward_score) + math.log(bigram_forward_score) + math.log(trigram_score)) @@ -88,7 +89,7 @@ def get_score(self, prev_token, possible_token, next_token): return result def first_token_case(self, raw): - return f'{raw[0].upper()}{raw[1:]}' + return raw.capitalize() def get_true_case(self, sentence, out_of_vocabulary_token_option="title"): """ Returns the true case for the passed tokens. @@ -99,7 +100,7 @@ def get_true_case(self, sentence, out_of_vocabulary_token_option="title"): lower: Returns OOV tokens in lower case as-is: Returns OOV tokens as is """ - tokens = self.tknzr.tokenize(sentence) + tokens = word_tokenize(sentence) tokens_true_case = [] for token_idx, token in enumerate(tokens): @@ -132,21 +133,20 @@ def get_true_case(self, sentence, out_of_vocabulary_token_option="title"): tokens_true_case.append(best_token) if token_idx == 0: - tokens_true_case[0] = self.first_token_case(tokens_true_case[0]) + tokens_true_case[0] = self.first_token_case( + tokens_true_case[0]) else: # Token out of vocabulary if out_of_vocabulary_token_option == "title": tokens_true_case.append(token.title()) + elif out_of_vocabulary_token_option == "capitalize": + tokens_true_case.append(token.capitalize()) elif out_of_vocabulary_token_option == "lower": tokens_true_case.append(token.lower()) else: tokens_true_case.append(token) - return "".join([ - " " + - i if not i.startswith("'") and i not in string.punctuation else i - for i in tokens_true_case - ]).strip() + return self.detknzr.detokenize(tokens_true_case) if __name__ == "__main__": diff --git a/truecase/__init__.py b/truecase/__init__.py index 67a4c19..6428d55 100644 --- a/truecase/__init__.py +++ b/truecase/__init__.py @@ -2,7 +2,7 @@ from .TrueCaser import TrueCaser -__version__ = "0.0.12" +__version__ = "0.0.13" @lru_cache(maxsize=1)