-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata.py
82 lines (56 loc) · 2.39 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# coding=utf-8
# /************************************************************************************
# ***
# *** File Author: Dell, Friday, September 21, 2018 10:25:44 CST
# ***
# ************************************************************************************/
import pickle
import torchtext
def text_token(x):
    """Tokenize *x* by splitting on single spaces, discarding empty tokens."""
    return list(filter(None, x.split(" ")))
# Shared field for the text column: space-tokenized, lowercased token sequences.
FastTextTEXT = torchtext.data.Field(sequential=True, tokenize=text_token, lower=True)
def label_token(x):
    """Return *x* wrapped in a one-element list, with all fastText
    '__label__' markers removed."""
    cleaned = x.replace("__label__", "")
    return [cleaned]
# Shared field for the label column. NOTE(review): with sequential=False the
# legacy Field likely never applies ``tokenize`` — label_token here is
# presumably inert (labels are also stripped in FastTextDataset) — confirm.
FastTextLABEL = torchtext.data.Field(sequential=False, tokenize=label_token, lower=True)
class FastTextDataset(torchtext.data.Dataset):
    """Dataset for fastText-style classification files.

    Each line is expected to look like ``<text><sep><label>``, where the
    label may carry a fastText ``__label__`` prefix (stripped on load).
    Lines that do not split into exactly two columns are skipped.
    """

    @staticmethod
    def sort_key(ex):
        # Batch examples of similar text length together.
        return len(ex.text)

    def __init__(self, path, text_field, label_field, sep='\t', **kwargs):
        """Create a dataset instance given a path and fields.

        Arguments:
            path: Path to the data file.
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            sep: Column separator between text and label (default: tab).
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        fields = [('text', text_field), ('label', label_field)]
        examples = []
        # FIX: specify the encoding explicitly — the original relied on the
        # platform locale default, which breaks on non-UTF-8 systems.
        # errors='ignore' keeps malformed bytes from aborting the load.
        with open(path, encoding='utf-8', errors='ignore') as f:
            for line in f:
                parts = line.strip().split(sep)
                if len(parts) != 2:
                    # Skip blank or malformed lines instead of failing.
                    continue
                text, label = parts
                label = label.replace("__label__", "")
                example = torchtext.data.Example()
                setattr(example, "text", text_field.preprocess(text))
                setattr(example, "label", label_field.preprocess(label))
                examples.append(example)
        super(FastTextDataset, self).__init__(examples, fields, **kwargs)
def fasttext_dataloader(datafile, batchsize, shuffle=False):
    """Build a batching iterator over *datafile* plus the vocab-initialized fields.

    Arguments:
        datafile: Path to a fastText-format text/label file.
        batchsize: Number of examples per batch.
        shuffle: Whether the iterator shuffles examples between epochs.

    Returns:
        (dataiter, text_field, label_field) — the iterator and the two
        module-level fields with their vocabularies built from the data.
    """
    text_field = FastTextTEXT
    label_field = FastTextLABEL
    dataset = FastTextDataset(datafile, text_field, label_field)
    text_field.build_vocab(dataset)
    label_field.build_vocab(dataset)
    # BUG FIX: Iterator's third positional parameter is ``sort_key``, not
    # ``shuffle`` — the original passed ``shuffle`` positionally, so it was
    # silently consumed as sort_key and shuffling was never enabled.
    dataiter = torchtext.data.Iterator(
        dataset, batchsize, shuffle=shuffle, repeat=False)
    # dataiter.init_epoch()
    return dataiter, text_field, label_field
def save_vocab(vocab, filename):
    """Serialize *vocab* to *filename* using pickle."""
    with open(filename, 'wb') as fp:
        pickle.dump(vocab, fp)
def load_vocab(filename):
    """Deserialize and return a pickled vocabulary from *filename*.

    NOTE: unpickling executes arbitrary code — only load trusted files.
    """
    with open(filename, 'rb') as fp:
        return pickle.load(fp)