-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
48 lines (39 loc) · 1.36 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from torch.utils.data import Dataset
# Canonical class-label order: sorted alphabetically so that label
# indices are stable regardless of how the literals are written.
intent_list = sorted(("hatespeech", "normal", "offensive"))
def format_data(raw):
    """Flatten a ``{label: [text, ...]}`` mapping into parallel lists.

    Args:
        raw: mapping from a label name (must be present in the
            module-level ``intent_list``) to an iterable of text samples.

    Returns:
        ``(texts, labels)`` where ``texts`` is the flat list of all
        samples and ``labels`` holds the corresponding ``intent_list``
        index for each sample.

    Raises:
        ValueError: if a key of ``raw`` is not found in ``intent_list``.
    """
    texts = []
    labels = []
    for label, data in raw.items():
        # Resolve the index once per label group rather than once per
        # sample (the original re-ran list.index inside the inner loop).
        label_id = intent_list.index(label)
        for text in data:
            texts.append(text)
            labels.append(label_id)
    return texts, labels
def get_dataset(tokenizer, dataset):
    """Build the (train, test) ``Transform_Dataset`` pair from a split dict.

    Args:
        tokenizer: callable tokenizer forwarded to ``Transform_Dataset``.
        dataset: dict with ``"train"`` and ``"test"`` keys, each holding a
            ``{label: [text, ...]}`` mapping accepted by ``format_data``.

    Returns:
        Tuple of ``(train_dataset, test_dataset)``.
    """
    built = []
    for split_name in ("train", "test"):
        texts, classes = format_data(dataset[split_name])
        built.append(Transform_Dataset(texts, classes, tokenizer))
    return tuple(built)
class Transform_Dataset(Dataset):
    """Torch ``Dataset`` that tokenizes text lazily on item access.

    Each item is the tokenizer's encoding dict for one text sample,
    augmented with a ``"label"`` entry holding the integer class index.
    """

    def __init__(self, text, label, tokenizer, unique_labels=None):
        """
        Args:
            text: list of raw text samples.
            label: list of integer class indices, parallel to ``text``.
            tokenizer: callable accepting ``(text, max_length=,
                truncation=, padding=)`` and returning a mutable mapping
                of encodings (e.g. a Hugging Face tokenizer).
            unique_labels: ordered label names; defaults to the
                module-level ``intent_list`` for backward compatibility.
        """
        self.text = text
        self.tokenizer = tokenizer
        self.label = label
        # Copy the label list so later mutation of the module-level
        # ``intent_list`` cannot silently change this dataset's mapping.
        self.unique_labels = list(
            intent_list if unique_labels is None else unique_labels
        )

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        # Tokenize on demand; fixed-length padding (max_length=80) keeps
        # batch tensors uniform in shape.
        item = self.tokenizer(
            self.text[idx], max_length=80, truncation=True, padding="max_length"
        )
        item["label"] = self.label[idx]
        return item

    @property
    def id2label(self):
        """Mapping from class index to label name."""
        return dict(enumerate(self.unique_labels))

    @property
    def label2id(self):
        """Mapping from label name to class index."""
        return {v: k for k, v in self.id2label.items()}