-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistilbert_hate_detection.py
41 lines (27 loc) · 1.04 KB
/
distilbert_hate_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
Developed by Aindriya Barua in October, 2021
"""
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import logging
from sklearn.model_selection import train_test_split
import pickle
# Cap the "transformers" logger at WARNING so its INFO-level records are dropped.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# Everything else logs at INFO and above through the root logger.
logging.basicConfig(level=logging.INFO)
def read_data(path):
    """Load an Excel workbook into a DataFrame, print it, and return it.

    Args:
        path: Location of the Excel file, passed straight to
            ``pandas.read_excel``.

    Returns:
        The DataFrame read from ``path``. ``dtype=object`` is requested so
        that pandas does not infer per-column dtypes.
    """
    frame = pd.read_excel(path, dtype=object)
    # Echo the loaded table to stdout as a quick visual sanity check.
    print(frame)
    return frame
# Load the spreadsheet and keep only the rows whose 'label' cell is filled in.
dataset_df = read_data('codemix-hin-hate.xlsx')
dataset_df = dataset_df[dataset_df['label'].notna()]
print(dataset_df)

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# identical from run to run.
train_df, eval_df = train_test_split(dataset_df, test_size=0.20, random_state=42)

# Fine-tune "distilbert-base-uncased" for a single epoch with CUDA disabled.
model_args = ClassificationArgs(num_train_epochs=1)
model = ClassificationModel(
    "distilbert", "distilbert-base-uncased", args=model_args, use_cuda=False
)
# NOTE(review): the filter above uses a 'label' column; confirm the sheet's
# column names/order match what simpletransformers expects for training data.
model.train_model(train_df)

# Evaluate on the held-out rows and print everything eval_model returns.
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
print(result, model_outputs, wrong_predictions)

# Persist the trained model. The context manager guarantees the file is
# flushed and closed even if pickling raises part-way through.
with open("model.pkl", "wb") as model_file:
    pickle.dump(model, model_file)