Bugfixes and rename parameters of SimpleLookAheadBucketing
Due to a rounding bug in estimating the minibatch size, minibatches could
slightly exceed the specified size. The estimate is now slightly more
pessimistic, so the guarantee holds.
Waino committed May 13, 2024
1 parent 0964b0a commit 2328fd3
Showing 4 changed files with 105 additions and 75 deletions.
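
For illustration (not code from the repository), the rounding bug described in the commit message can be reproduced with a minimal sketch. It assumes a 12-token batch_size, 4 minibatches per maxi-batch, and uniformly 5-token examples, and mirrors the old and new admission checks in the diff below:

import math

batch_size, k, score = 12, 4, 5   # token budget, minibatches per maxi-batch, tokens per example

# Old admission check: compare against the whole maxi-batch budget (k * batch_size).
n_old = 0
while score * (n_old + 1) < k * batch_size:
    n_old += 1
# Old split: ceil(n / k) examples per minibatch.
print(n_old, score * math.ceil(n_old / k))   # 9 15 -> a 15-token minibatch exceeds the 12-token budget

# New admission check: pessimistically estimate the largest resulting minibatch.
n_new = 0
while score * math.ceil((n_new + 1) / k) < batch_size:
    n_new += 1
# New split: at most ceil(n / k) examples per minibatch.
print(n_new, score * math.ceil(n_new / k))   # 8 10 -> every minibatch stays within the budget

With the old check the maxi-batch grows to 9 examples and the ceil-based split produces 15-token minibatches; the new check stops at 8 examples, so every minibatch stays within the 12-token limit.
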
75 changes: 46 additions & 29 deletions mammoth/inputters/dataloader.py
@@ -8,7 +8,15 @@
from mammoth.utils.logging import logger


def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
def build_dataloader(
dataset,
batch_size,
batch_type,
max_look_ahead_sentences=None,
lookahead_minibatches=None,
cycle=True,
as_iter=True
):
"""Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
if not cycle:
loader = InferenceBatcher(dataset, batch_size)
@@ -18,8 +26,8 @@ def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=
elif batch_type == 'tokens':
loader = SimpleLookAheadBucketing(
dataset=dataset,
max_look_ahead_size=pool_size,
n_buckets=n_buckets,
max_look_ahead_sentences=max_look_ahead_sentences,
lookahead_minibatches=lookahead_minibatches,
batch_size=batch_size,
score_fn=SimpleLookAheadBucketing.max_of_lens,
)
@@ -75,13 +83,13 @@ class SimpleLookAheadBucketing():
"""
Arguments:
dataset: mammoth.inputters.ParallelCorpus
max_look_ahead_size:
max_look_ahead_sentences:
The maximum number of sentence pairs to read before yielding minibatches.
Limits the time spent looping if there is a corpus with unexpectedly short sentences.
n_buckets:
lookahead_minibatches:
The number of minibatches that will be yielded once bucketing is complete.
Recommended value: same as accum_count, or at least a multiple of it.
Setting n_buckets == accum_count means that each accumulated batch uses up the whole buffer.
Setting lookahead_minibatches == accum_count means that each accumulated batch uses up the whole buffer.
All tasks stay in sync concerning the length sorting: each task begins with the smallest
minibatch and ends with the largest just before accumulation ends.
batch_size:
@@ -91,13 +99,12 @@ class SimpleLookAheadBucketing():
score_fn:
Compute the size estimate (single integer) for sorting examples.
"""
def __init__(self, dataset, max_look_ahead_size, n_buckets, batch_size, score_fn=None):
def __init__(self, dataset, max_look_ahead_sentences, lookahead_minibatches, batch_size, score_fn=None):
score_fn = score_fn if score_fn else self.max_of_lens
self._sie = ScoredInfiniteExamples(dataset, score_fn)
self.max_look_ahead_size = max_look_ahead_size
self.max_look_ahead_sentences = max_look_ahead_sentences
self.batch_size = batch_size
self.n_buckets = n_buckets
self.multi_batch_size = n_buckets * batch_size
self.lookahead_minibatches = lookahead_minibatches
self.collate_fn = dataset.collate_fn

@staticmethod
@@ -110,32 +117,42 @@ def max_of_lens(example_dict) -> int:

def __iter__(self):
while True:
multibatch = []
maxi_batch = []
max_score = 0
for i in range(self.max_look_ahead_size):
for i in range(self.max_look_ahead_sentences):
score = self._sie.peek_at_score()
# Decide whether to add it or not
if len(multibatch) <= self.n_buckets:
if len(maxi_batch) < self.lookahead_minibatches:
# Always add at least one example per minibatch
still_fits = True
else:
still_fits = (max(max_score, score) * (len(multibatch) + 1)) < (self.multi_batch_size)
estimated_minibatch_size = math.ceil((len(maxi_batch) + 1) / self.lookahead_minibatches)
still_fits = (max(max_score, score) * estimated_minibatch_size) < (self.batch_size)
print(f'estimated_minibatch_size {estimated_minibatch_size} still_fits {still_fits}')
if still_fits:
score, example = self._sie.next()
multibatch.append((score, example))
maxi_batch.append((score, example))
max_score = max(max_score, score)
else:
break
# Sort by score to reduce padding
multibatch = list(sorted(multibatch, key=lambda x: x[0]))
maxi_batch = list(sorted(maxi_batch, key=lambda x: x[0]))
# Split into minibatches and yield
examples_per_batch = math.ceil(len(multibatch) / self.n_buckets)
multibatch_it = iter(multibatch)
for _ in range(self.n_buckets):
floor_examples_per_batch = math.floor(len(maxi_batch) / self.lookahead_minibatches)
examples_per_batch = [floor_examples_per_batch] * self.lookahead_minibatches
for i in range(len(maxi_batch) % self.lookahead_minibatches):
examples_per_batch[i] += 1
assert all(epb > 0 for epb in examples_per_batch)
assert sum(examples_per_batch) == len(maxi_batch)
print('examples_per_batch', examples_per_batch)
print('maxi', maxi_batch)
maxi_batch_it = iter(maxi_batch)
for epb in examples_per_batch:
print('epb', epb)
yield self.collate_fn(
[
example_dict for _, example_dict
in itertools.islice(multibatch_it, examples_per_batch)
in itertools.islice(maxi_batch_it, epb)
]
)

@@ -153,7 +170,7 @@ class DynamicDatasetIter(object):
batch_size (int): numbers of examples in a batch;
batch_size_multiple (int): make batch size multiply of this;
data_type (str): input data type, currently only text;
pool_size (int): accum this number of examples in a dynamic dataset;
max_look_ahead_sentences (int): accum this number of examples in a dynamic dataset;
skip_empty_level (str): security level when encouter empty line;
stride (int): iterate data files with this stride;
offset (int): iterate data files with this offset.
@@ -175,8 +192,8 @@ def __init__(
batch_size,
batch_size_multiple,
data_type="text",
pool_size=2048,
n_buckets=1024,
max_look_ahead_sentences=2048,
lookahead_minibatches=4,
skip_empty_level='warning',
):
self.task_queue_manager = task_queue_manager
Expand All @@ -191,8 +208,8 @@ def __init__(
self.batch_size = batch_size
self.batch_size_multiple = batch_size_multiple
self.device = 'cpu'
self.pool_size = pool_size
self.n_buckets = n_buckets
self.max_look_ahead_sentences = max_look_ahead_sentences
self.lookahead_minibatches = lookahead_minibatches
if skip_empty_level not in ['silent', 'warning', 'error']:
raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
self.skip_empty_level = skip_empty_level
@@ -216,8 +233,8 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra
batch_size,
batch_size_multiple,
data_type=opts.data_type,
pool_size=opts.pool_size,
n_buckets=opts.n_buckets,
max_look_ahead_sentences=opts.max_look_ahead_sentences,
lookahead_minibatches=opts.lookahead_minibatches,
skip_empty_level=opts.skip_empty_level,
)

@@ -251,8 +268,8 @@ def _init_datasets(self):
corpus,
self.batch_size,
self.batch_type,
self.pool_size,
n_buckets=self.n_buckets,
self.max_look_ahead_sentences,
lookahead_minibatches=self.lookahead_minibatches,
cycle=self.is_train,
as_iter=self.is_train,
)
20 changes: 11 additions & 9 deletions mammoth/opts.py
@@ -1010,19 +1010,21 @@ def _add_train_dynamic_data(parser):
def _add_train_dynamic_data(parser):
group = parser.add_argument_group("Dynamic data")
group.add(
"-pool_size",
"--pool_size",
"-lookahead_minibatches",
"--lookahead_minibatches",
type=int,
default=2048,
help="(Maximum) number of examples to dynamically pool before batching.",
default=4,
help="The number of minibatches that SimpleLookAheadBucketing will read into a maxibatch, "
"pessimisticly sort by length, split into minibatches, and yield in one go. "
"Recommended value: same as accum_count, or at least a multiple of it."
)
group.add(
"-n_buckets",
"--n_buckets",
"-max_look_ahead_sentences",
"--max_look_ahead_sentences",
type=int,
default=4,
help="The number of minibatches that will be yielded once bucketing is complete. "
"Recommended value: same as accum_count, or at least a multiple of it."
default=2048,
help="(Maximum) number of sentence pairs that SimpleLookAheadBucketing can attempt to add to the maxibatch. "
"This is mainly a failsafe in case some corpus contains very short examples.",
)


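For orientation (not part of the diff), the two renamed options end up as keyword arguments of build_dataloader. A hypothetical call with illustrative values might look as follows, with corpus standing for a mammoth.inputters.ParallelCorpus constructed elsewhere:

from mammoth.inputters.dataloader import build_dataloader

loader = build_dataloader(
    corpus,                          # placeholder: a ParallelCorpus built elsewhere
    batch_size=4096,                 # token budget per minibatch when batch_type='tokens'
    batch_type='tokens',
    max_look_ahead_sentences=2048,   # failsafe cap on sentence pairs read per maxi-batch
    lookahead_minibatches=8,         # e.g. equal to accum_count, as recommended above
    cycle=True,
    as_iter=True,
)
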
81 changes: 46 additions & 35 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -1,6 +1,6 @@
import pytest
from itertools import product

import unittest
from mammoth.inputters.dataloader import build_dataloader


@@ -26,37 +26,48 @@ def collate_fn(self, items):
return items


class TestLookAheadBucketing(unittest.TestCase):

def test_all_read(self):
max_batch_size = 12
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(
stream,
batch_size=max_batch_size,
batch_type='tokens',
pool_size=4,
n_buckets=4,
cycle=True,
as_iter=False
)
examples_read = []
batches = iter(lab)
for _ in range(1000):
batch = next(batches)
assert len(batch) > 0
src_toks = sum(len(ex['src']) for ex in batch)
tgt_toks = sum(len(ex['tgt']) for ex in batch)
# check that the batch size is respected
assert src_toks <= max_batch_size
assert tgt_toks <= max_batch_size, str(batch)
examples_read.extend(batch)
# Check that the stream was cycled
self.assertTrue(len(examples_read) > len(stream))
@pytest.mark.parametrize(
('max_batch_size', 'lookahead_minibatches'),
[
(12, 4),
(13, 4),
(14, 4),
(15, 4),
(12, 5),
(13, 5),
(14, 5),
(15, 5),
],
)
def test_simple_lookeahead_bucketing(max_batch_size, lookahead_minibatches):
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(
stream,
batch_size=max_batch_size,
batch_type='tokens',
max_look_ahead_sentences=512,
lookahead_minibatches=lookahead_minibatches,
cycle=True,
as_iter=False
)
examples_read = []
batches = iter(lab)
for _ in range(1000):
batch = next(batches)
print(batch)
assert len(batch) > 0
src_toks = sum(len(ex['src']) for ex in batch)
tgt_toks = sum(len(ex['tgt']) for ex in batch)
# check that the batch size is respected
assert src_toks <= max_batch_size
assert tgt_toks <= max_batch_size, str(batch)
examples_read.extend(batch)
# Check that the stream was cycled
assert len(examples_read) > len(stream)
4 changes: 2 additions & 2 deletions mammoth/translate/translator.py
@@ -479,8 +479,8 @@ def _translate(
corpus,
batch_size=batch_size,
batch_type=batch_type,
pool_size=512,
n_buckets=512,
max_look_ahead_sentences=512,
lookahead_minibatches=512,
cycle=False,
)

