-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSentimentClassification_and_WordFrequency.py
69 lines (50 loc) · 1.88 KB
/
SentimentClassification_and_WordFrequency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# %%
import matplotlib.pyplot as plt
import nltk
from nltk.corpus import stopwords
from transformers import AutoTokenizer, TFDistilBertForSequenceClassification
import string
# Fetch NLTK's stopword corpus; the filtering cell reads the English list.
nltk.download("stopwords")
# Tokenizer for the DistilBERT base (uncased) vocabulary.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Sequence-classification model restored from a saved checkpoint directory.
# NOTE(review): "/saved_model" is an absolute path at the filesystem root;
# confirm this is intended and not "./saved_model".
model = TFDistilBertForSequenceClassification.from_pretrained("/saved_model")
# %%
# Encode the sample sentence as TensorFlow tensors. Despite its name,
# `input_ids` holds the whole tokenizer output (indexed by 'input_ids' below).
input_ids = tokenizer(
    'This is my my dog. I have some cute adorable pictures of my dog. Dog has a cute smile. I love my dog',
    return_tensors='tf',
)
# Forward pass through the classifier.
preds = model(input_ids)
# Token strings of the first (only) encoded sequence, special tokens included.
tokens = tokenizer.convert_ids_to_tokens(input_ids['input_ids'][0])
# %%
# Index of the highest-scoring class for the first (only) input row, mapped
# to its label name through the model config.
label_index = preds[0][0].numpy().argmax()
sentiment = model.config.id2label[label_index]
# %%
# Keep only "content" tokens for the word-frequency statistics: drop the
# tokenizer's special tokens ([CLS], [SEP], [PAD], ...), English stopwords,
# punctuation-only tokens and digit-only tokens.
stwrds = stopwords.words('english')
# Membership is tested against a flat set of individual tokens. Nesting the
# stopword list (or the punctuation/digit strings) inside a list would compare
# each token with those containers as single objects and never match.
excluded_tokens = set(stwrds) | set(tokenizer.all_special_tokens)
filtered_words = [
    word
    for word in tokens
    if word not in excluded_tokens
    and not all(ch in string.punctuation for ch in word)
    and not word.isdigit()
]
# Frequency distribution over the filtered tokens.
freq = nltk.FreqDist(filtered_words)
# (word, count) pairs for words occurring more than once; these get bolded.
new_set = [(sub, val) for sub, val in freq.items() if val > 1]
# The four most frequent filtered words with their counts.
top = freq.most_common(4)
# %%
# Wrap every token that occurs more than once (per `new_set`) in markdown
# bold markers. A set gives O(1) membership, so the token list is rewritten
# in a single pass rather than once per frequent word.
frequent_words = {sub for sub, _ in new_set}
tokens = ['**' + word + '**' if word in frequent_words else word for word in tokens]
# Final output: rebuild readable text from the (possibly bolded) tokens.
new_text = tokenizer.convert_tokens_to_string(tokens)
# Remove the [CLS]/[SEP] markers and the whitespace they leave at the ends.
new_text = new_text.replace('[CLS]', '').replace('[SEP]', '').strip()
# %%
# Print the report. Each entry is the positional-argument tuple for one
# print() call, so the output is the same as issuing the calls one by one.
for print_args in (
    ("**********************************\n",),
    ("Sentiment of sentence is: ", sentiment),
    ("\n",),
    ("The text with most frequent words bolded is: \n",),
    (new_text,),
    ("\n",),
    ("The top 4 most frequent words are: \n",),
    (top,),
):
    print(*print_args)