# standard imports
import os
import codecs
from random import shuffle
from typing import List, Tuple
from itertools import groupby

# third-party imports
import youtokentome
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences


class SequenceLoader(object):
    """
    An iterator for loading batches of data into the Transformer.

    For training:
        each batch contains approximately tokens_in_batch target language tokens;
        the target language sequences in a batch all have the same length, and the
        source language sequences have very similar lengths, which minimizes padding
        and therefore memory usage. Batches are also shuffled.

    For validation and testing:
        each batch contains a single source-target pair, in the same order as in the
        files from which they were read.
    """

    def __init__(
        self,
        data_folder: str,
        source_suffix: str,
        target_suffix: str,
        split: str,
        tokens_in_batch: int,
    ):
        """
        Initializes the loader: reads the source and target files for the given split,
        computes sequence lengths, and creates batches.

        :param data_folder: folder containing the source and target language data files
        :param source_suffix: the filename suffix for the source language files
        :param target_suffix: the filename suffix for the target language files
        :param split: one of "train", "val", or "test"
        :param tokens_in_batch: the number of target language tokens in each batch
        """
        self.n_batches = None
        self.all_batches = None
        self.current_batch = None
        self.tokens_in_batch = tokens_in_batch
        self.source_suffix = source_suffix
        self.target_suffix = target_suffix
        assert split.lower() in {"train", "test", "val"}, "no such split"
        self.split = split.lower()
        self.for_training = self.split == "train"
        self.bpe_model = youtokentome.BPE(model=os.path.join(data_folder, "bpe.model"))
        with codecs.open(
            os.path.join(data_folder, ".".join([split, source_suffix])),
            "r",
            encoding="utf-8",
        ) as f:
            source_data = f.read().split("\n")
        with codecs.open(
            os.path.join(data_folder, ".".join([split, target_suffix])),
            "r",
            encoding="utf-8",
        ) as f:
            target_data = f.read().split("\n")
        assert len(source_data) == len(
            target_data
        ), "different number of source and target sequences"
        source_lengths = [
            len(s) for s in self.bpe_model.encode(source_data, bos=False, eos=False)
        ]
        target_lengths = [
            len(t) for t in self.bpe_model.encode(target_data, bos=True, eos=True)
        ]
        self.data = list(zip(source_data, target_data, source_lengths, target_lengths))
        if self.for_training:
            # sort by target sequence length, as required by groupby() in create_batches()
            self.data.sort(key=lambda x: x[3])
        self.create_batches()

    def create_batches(self):
        """
        Prepares batches for one epoch.
        """
        if self.for_training:
            # group the data into chunks that share the same target sequence length
            chunks = [list(g) for _, g in groupby(self.data, key=lambda x: x[3])]
            self.all_batches = list()
            for chunk in chunks:
                # sort so that a batch also has similar source sequence lengths
                chunk.sort(key=lambda x: x[2])
                # divide the token budget per batch by the target sequence length in
                # this chunk to get the number of sequences per batch
                seqs_per_batch = self.tokens_in_batch // chunk[0][3]
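                # e.g. a token budget of 2000 and a chunk whose target length is 25
                # yields 2000 // 25 = 80 sequences per batch (illustrative numbers)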
                self.all_batches.extend(
                    [
                        chunk[i : i + seqs_per_batch]
                        for i in range(0, len(chunk), seqs_per_batch)
                    ]
                )
            shuffle(self.all_batches)  # shuffle batches
            self.n_batches = len(self.all_batches)
            self.current_batch = -1
        else:
            self.all_batches = [[d] for d in self.data]
            self.n_batches = len(self.all_batches)
            self.current_batch = -1

    def __iter__(self):
        """Required by the iterator protocol; the loader is its own iterator."""
        return self

    def get_vocabulary(self) -> List[str]:
        """
        Returns a list of all unique tokens in the vocabulary.

        :returns: a list of strings
        """
        return self.bpe_model.vocab()

    def __next__(self) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Next in iterator.

        :returns: the next batch, containing:
            source language sequences, a tensor of size (N, encoder_sequence_pad_length)
            target language sequences, a tensor of size (N, decoder_sequence_pad_length)
            true source language lengths, a tensor of size (N)
            true target language lengths, a tensor of size (N)
        """
        self.current_batch += 1
        try:
            source_data, target_data, source_lengths, target_lengths = zip(
                *self.all_batches[self.current_batch]
            )
        except IndexError:
            raise StopIteration
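        # the raw text is re-encoded to token ids per batch; __init__ keeps only the
        # raw strings and their token lengths, not the encoded ids for the whole corpus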
        source_data = self.bpe_model.encode(
            source_data, output_type=youtokentome.OutputType.ID, bos=False, eos=False
        )
        target_data = self.bpe_model.encode(
            target_data, output_type=youtokentome.OutputType.ID, bos=True, eos=True
        )
        source_data = pad_sequences(
            sequences=source_data,
            padding="post",
            value=self.bpe_model.subword_to_id("<PAD>"),
        )
        target_data = pad_sequences(
            sequences=target_data,
            padding="post",
            value=self.bpe_model.subword_to_id("<PAD>"),
        )
        source_data = tf.convert_to_tensor(source_data, dtype=tf.int32)
        target_data = tf.convert_to_tensor(target_data, dtype=tf.int32)
        source_lengths = tf.convert_to_tensor(source_lengths, dtype=tf.int32)
        target_lengths = tf.convert_to_tensor(target_lengths, dtype=tf.int32)
        return source_data, target_data, source_lengths, target_lengths
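

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the folder name, file suffixes, and
    # token budget below are placeholders. The data folder is expected to contain
    # bpe.model plus files like train.en / train.de produced by a separate
    # preprocessing step.
    train_loader = SequenceLoader(
        data_folder="data",
        source_suffix="en",
        target_suffix="de",
        split="train",
        tokens_in_batch=2000,
    )
    for source_seqs, target_seqs, source_lens, target_lens in train_loader:
        print(source_seqs.shape, target_seqs.shape, source_lens.shape, target_lens.shape)
        break  # inspect just the first batch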