diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py
index 67f9bed6..272182c8 100644
--- a/mammoth/inputters/dataloader.py
+++ b/mammoth/inputters/dataloader.py
@@ -8,7 +8,15 @@ from mammoth.utils.logging import logger
 
 
-def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
+def build_dataloader(
+    dataset,
+    batch_size,
+    batch_type,
+    max_look_ahead_sentences=None,
+    lookahead_minibatches=None,
+    cycle=True,
+    as_iter=True
+):
     """Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
     if not cycle:
         loader = InferenceBatcher(dataset, batch_size)
@@ -18,8 +26,8 @@ def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=
     elif batch_type == 'tokens':
         loader = SimpleLookAheadBucketing(
             dataset=dataset,
-            max_look_ahead_size=pool_size,
-            n_buckets=n_buckets,
+            max_look_ahead_sentences=max_look_ahead_sentences,
+            lookahead_minibatches=lookahead_minibatches,
             batch_size=batch_size,
             score_fn=SimpleLookAheadBucketing.max_of_lens,
         )
@@ -75,13 +83,13 @@ class SimpleLookAheadBucketing():
     """
     Arguments:
         dataset: mammoth.inputters.ParallelCorpus
-        max_look_ahead_size:
+        max_look_ahead_sentences:
            The maximum number of sentence pairs to read before yielding minibatches.
            Limits the time spent looping if there is a corpus with unexpectedly short sentences.
-        n_buckets:
+        lookahead_minibatches:
            The number of minibatches that will be yielded once bucketing is complete.
            Recommended value: same as accum_count, or at least a multiple of it.
-           Setting n_buckets == accum_count means that each accumulated batch uses up the whole buffer.
+           Setting lookahead_minibatches == accum_count means that each accumulated batch uses up the whole buffer.
            All tasks stay in sync concerning the length sorting: each task begins with the smallest minibatch
            and ends with the largest just before accumulation ends.
        batch_size:
@@ -91,13 +99,12 @@ class SimpleLookAheadBucketing():
        score_fn:
            Compute the size estimate (single integer) for sorting examples.
    """
""" - def __init__(self, dataset, max_look_ahead_size, n_buckets, batch_size, score_fn=None): + def __init__(self, dataset, max_look_ahead_sentences, lookahead_minibatches, batch_size, score_fn=None): score_fn = score_fn if score_fn else self.max_of_lens self._sie = ScoredInfiniteExamples(dataset, score_fn) - self.max_look_ahead_size = max_look_ahead_size + self.max_look_ahead_sentences = max_look_ahead_sentences self.batch_size = batch_size - self.n_buckets = n_buckets - self.multi_batch_size = n_buckets * batch_size + self.lookahead_minibatches = lookahead_minibatches self.collate_fn = dataset.collate_fn @staticmethod @@ -110,32 +117,42 @@ def max_of_lens(example_dict) -> int: def __iter__(self): while True: - multibatch = [] + maxi_batch = [] max_score = 0 - for i in range(self.max_look_ahead_size): + for i in range(self.max_look_ahead_sentences): score = self._sie.peek_at_score() # Decide whether to add it or not - if len(multibatch) <= self.n_buckets: + if len(maxi_batch) < self.lookahead_minibatches: # Always add at least one example per minibatch still_fits = True else: - still_fits = (max(max_score, score) * (len(multibatch) + 1)) < (self.multi_batch_size) + estimated_minibatch_size = math.ceil((len(maxi_batch) + 1) / self.lookahead_minibatches) + still_fits = (max(max_score, score) * estimated_minibatch_size) < (self.batch_size) + print(f'estimated_minibatch_size {estimated_minibatch_size} still_fits {still_fits}') if still_fits: score, example = self._sie.next() - multibatch.append((score, example)) + maxi_batch.append((score, example)) max_score = max(max_score, score) else: break # Sort by score to reduce padding - multibatch = list(sorted(multibatch, key=lambda x: x[0])) + maxi_batch = list(sorted(maxi_batch, key=lambda x: x[0])) # Split into minibatches and yield - examples_per_batch = math.ceil(len(multibatch) / self.n_buckets) - multibatch_it = iter(multibatch) - for _ in range(self.n_buckets): + floor_examples_per_batch = math.floor(len(maxi_batch) / self.lookahead_minibatches) + examples_per_batch = [floor_examples_per_batch] * self.lookahead_minibatches + for i in range(len(maxi_batch) % self.lookahead_minibatches): + examples_per_batch[i] += 1 + assert all(epb > 0 for epb in examples_per_batch) + assert sum(examples_per_batch) == len(maxi_batch) + print('examples_per_batch', examples_per_batch) + print('maxi', maxi_batch) + maxi_batch_it = iter(maxi_batch) + for epb in examples_per_batch: + print('epb', epb) yield self.collate_fn( [ example_dict for _, example_dict - in itertools.islice(multibatch_it, examples_per_batch) + in itertools.islice(maxi_batch_it, epb) ] ) @@ -153,7 +170,7 @@ class DynamicDatasetIter(object): batch_size (int): numbers of examples in a batch; batch_size_multiple (int): make batch size multiply of this; data_type (str): input data type, currently only text; - pool_size (int): accum this number of examples in a dynamic dataset; + max_look_ahead_sentences (int): accum this number of examples in a dynamic dataset; skip_empty_level (str): security level when encouter empty line; stride (int): iterate data files with this stride; offset (int): iterate data files with this offset. 
diff --git a/mammoth/opts.py b/mammoth/opts.py
index dca0b2e1..cf01649e 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -1010,19 +1010,21 @@ def _add_train_general_opts(parser):
 
 def _add_train_dynamic_data(parser):
     group = parser.add_argument_group("Dynamic data")
     group.add(
-        "-pool_size",
-        "--pool_size",
+        "-lookahead_minibatches",
+        "--lookahead_minibatches",
         type=int,
-        default=2048,
-        help="(Maximum) number of examples to dynamically pool before batching.",
+        default=4,
+        help="The number of minibatches that SimpleLookAheadBucketing will read into a maxibatch, "
+        "pessimistically sort by length, split into minibatches, and yield in one go. "
+        "Recommended value: same as accum_count, or at least a multiple of it."
     )
     group.add(
-        "-n_buckets",
-        "--n_buckets",
+        "-max_look_ahead_sentences",
+        "--max_look_ahead_sentences",
         type=int,
-        default=4,
-        help="The number of minibatches that will be yielded once bucketing is complete. "
-        "Recommended value: same as accum_count, or at least a multiple of it."
+        default=2048,
+        help="(Maximum) number of sentence pairs that SimpleLookAheadBucketing can attempt to add to the maxibatch. "
+        "This is mainly a failsafe in case some corpus contains very short examples.",
     )
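
For reference, the renamed flags are passed like any other integer options on the training command line. A rough equivalent with plain argparse (assuming the group.add wrapper above behaves like add_argument, which its call signature suggests; the values shown are just the new defaults):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("Dynamic data")
    group.add_argument(
        "-lookahead_minibatches", "--lookahead_minibatches", type=int, default=4,
        help="Minibatches yielded per maxibatch; recommended to match accum_count.",
    )
    group.add_argument(
        "-max_look_ahead_sentences", "--max_look_ahead_sentences", type=int, default=2048,
        help="Upper bound on sentence pairs read into one maxibatch.",
    )
    opts = parser.parse_args(["--lookahead_minibatches", "4", "--max_look_ahead_sentences", "2048"])
    assert opts.lookahead_minibatches == 4 and opts.max_look_ahead_sentences == 2048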
" + "This is mainly a failsafe in case some corpus contains very short examples.", ) diff --git a/mammoth/tests/test_look_ahead_bucketing.py b/mammoth/tests/test_look_ahead_bucketing.py index 45bbd1a4..e74bcf9f 100644 --- a/mammoth/tests/test_look_ahead_bucketing.py +++ b/mammoth/tests/test_look_ahead_bucketing.py @@ -1,6 +1,6 @@ +import pytest from itertools import product -import unittest from mammoth.inputters.dataloader import build_dataloader @@ -26,37 +26,48 @@ def collate_fn(self, items): return items -class TestLookAheadBucketing(unittest.TestCase): - - def test_all_read(self): - max_batch_size = 12 - stream = MockStream([ - hashabledict({ - 'src': tuple([letter for _ in range(i)]), - 'tgt': tuple([letter for _ in range(j)]), - }) - for letter in 'xyz' - for i, j in product(range(1, 11), range(1, 11)) - ]) - lab = build_dataloader( - stream, - batch_size=max_batch_size, - batch_type='tokens', - pool_size=4, - n_buckets=4, - cycle=True, - as_iter=False - ) - examples_read = [] - batches = iter(lab) - for _ in range(1000): - batch = next(batches) - assert len(batch) > 0 - src_toks = sum(len(ex['src']) for ex in batch) - tgt_toks = sum(len(ex['tgt']) for ex in batch) - # check that the batch size is respected - assert src_toks <= max_batch_size - assert tgt_toks <= max_batch_size, str(batch) - examples_read.extend(batch) - # Check that the stream was cycled - self.assertTrue(len(examples_read) > len(stream)) +@pytest.mark.parametrize( + ('max_batch_size', 'lookahead_minibatches'), + [ + (12, 4), + (13, 4), + (14, 4), + (15, 4), + (12, 5), + (13, 5), + (14, 5), + (15, 5), + ], +) +def test_simple_lookeahead_bucketing(max_batch_size, lookahead_minibatches): + stream = MockStream([ + hashabledict({ + 'src': tuple([letter for _ in range(i)]), + 'tgt': tuple([letter for _ in range(j)]), + }) + for letter in 'xyz' + for i, j in product(range(1, 11), range(1, 11)) + ]) + lab = build_dataloader( + stream, + batch_size=max_batch_size, + batch_type='tokens', + max_look_ahead_sentences=512, + lookahead_minibatches=lookahead_minibatches, + cycle=True, + as_iter=False + ) + examples_read = [] + batches = iter(lab) + for _ in range(1000): + batch = next(batches) + print(batch) + assert len(batch) > 0 + src_toks = sum(len(ex['src']) for ex in batch) + tgt_toks = sum(len(ex['tgt']) for ex in batch) + # check that the batch size is respected + assert src_toks <= max_batch_size + assert tgt_toks <= max_batch_size, str(batch) + examples_read.extend(batch) + # Check that the stream was cycled + assert len(examples_read) > len(stream) diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py index cb151c10..092d421a 100644 --- a/mammoth/translate/translator.py +++ b/mammoth/translate/translator.py @@ -479,8 +479,8 @@ def _translate( corpus, batch_size=batch_size, batch_type=batch_type, - pool_size=512, - n_buckets=512, + max_look_ahead_sentences=512, + lookahead_minibatches=512, cycle=False, )