From 8fd079d9b36216c3ae656917320aa9e0b4f7bc1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 9 Sep 2024 17:07:55 +0300 Subject: [PATCH] WIP: bug in beam search kv cache --- mammoth/tests/test_beam_search.py | 30 +- mammoth/translate/beam_search.py | 5 +- mammoth/translate/decode_strategy.py | 2 +- mammoth/translate/translator.py | 2 +- tools/generate_synth_data.py | 391 +++++++++++++++++++++++++++ tools/iterate_tasks.py | 58 ++++ 6 files changed, 473 insertions(+), 15 deletions(-) create mode 100644 tools/generate_synth_data.py create mode 100644 tools/iterate_tasks.py diff --git a/mammoth/tests/test_beam_search.py b/mammoth/tests/test_beam_search.py index 9af69561..0c2d51ef 100644 --- a/mammoth/tests/test_beam_search.py +++ b/mammoth/tests/test_beam_search.py @@ -61,7 +61,8 @@ def test_advance_with_all_repeats_gets_blocked(self): word_probs[:, repeat_idx] = 0 attns = torch.randn(1, batch_sz * beam_sz, 53) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < ngram_repeat: # before repeat, scores are either 0 or -inf @@ -130,7 +131,8 @@ def test_advance_with_some_repeats_gets_blocked(self): # continue pushing around what beam 1 predicts word_probs[1::beam_sz, repeat_idx + i + 1] = 0 attns = torch.randn(1, batch_sz * beam_sz, 53) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < ngram_repeat: self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()) self.assertFalse(beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).any()) @@ -196,7 +198,8 @@ def test_repeating_excluded_index_does_not_die(self): # predict the allowed-repeat again in beam 2 word_probs[2::beam_sz, repeat_idx_ignored] = 0 attns = torch.randn(1, batch_sz * beam_sz, 53) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < ngram_repeat: self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any()) self.assertFalse(beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).any()) @@ -245,7 +248,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self): device=device_init.device, ) beam.initialize() - all_attns = [] for i in range(min_length + 4): # non-interesting beams are going to get dummy values word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf')) @@ -265,16 +267,18 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self): word_probs[beam_idx::beam_sz, j] = score attns = torch.randn(1, batch_sz * beam_sz, 53) - all_attns.append(attns) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < min_length: expected_score_dist = (i + 1) * valid_score_dist[1:].unsqueeze(0) # Note that when batch_sz is > 1, expected is broadcast across the batch self.assertTrue(beam.topk_log_probs.allclose(expected_score_dist)) + self.assertTrue(beam.cache.shape == torch.Size([1, batch_sz * beam_sz, 53])) elif i == min_length: # now the top beam has ended and no others have self.assertTrue(beam.is_finished[:, 0].eq(1).all()) self.assertTrue(beam.is_finished[:, 1:].eq(0).all()) + self.assertTrue(beam.cache.shape == torch.Size([1, batch_sz * (beam_sz - 1), 53])) else: # i > min_length # not of interest, but want to make sure it keeps running # since only beam 0 terminates and n_best = 2 @@ -337,7 +341,8 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self): word_probs[beam_idx::beam_sz, j] = score attns = torch.randn(1, batch_sz * beam_sz, 53) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < min_length: self.assertFalse(beam.done) elif i == min_length: @@ -408,7 +413,8 @@ def test_beam_returns_attn_with_correct_length(self): word_probs[beam_idx::beam_sz, j] = score attns = torch.randn(1, batch_sz * beam_sz, 53) - beam.advance(word_probs, attns) + beam.set_cache(attns) + beam.advance(word_probs) if i < min_length: self.assertFalse(beam.done) # no top beams are finished yet @@ -465,7 +471,7 @@ def init_step(self, beam, expected_len_pen): expected_beam_scores, expected_preds_0 = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS).topk( self.BEAM_SZ, dim=-1 ) - beam.advance(deepcopy(init_scores), self.random_attn()) + beam.advance(deepcopy(init_scores)) self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores)) self.assertTrue(beam.topk_ids.equal(expected_preds_0)) self.assertFalse(beam.is_finished.any()) @@ -489,7 +495,7 @@ def first_step(self, beam, expected_beam_scores, expected_len_pen): ) scores_1 = scores_1.repeat(self.BATCH_SZ, 1) - beam.advance(deepcopy(scores_1), self.random_attn()) + beam.advance(deepcopy(scores_1)) new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1) expected_beam_scores, unreduced_preds = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS).topk( @@ -525,7 +531,7 @@ def second_step(self, beam, expected_beam_scores, expected_len_pen): ) scores_2 = scores_2.repeat(self.BATCH_SZ, 1) - beam.advance(deepcopy(scores_2), self.random_attn()) + beam.advance(deepcopy(scores_2)) # ended beam 2 shouldn't continue expected_beam_scores[:, 2::self.BEAM_SZ] = self.DEAD_SCORE @@ -568,7 +574,7 @@ def third_step(self, beam, expected_beam_scores, expected_len_pen): ) scores_3 = scores_3.repeat(self.BATCH_SZ, 1) - beam.advance(deepcopy(scores_3), self.random_attn()) + beam.advance(deepcopy(scores_3)) expected_beam_scores[:, 0::self.BEAM_SZ] = self.DEAD_SCORE new_scores = scores_3 + expected_beam_scores.view(-1).unsqueeze(1) diff --git a/mammoth/translate/beam_search.py b/mammoth/translate/beam_search.py index dd211813..1fd22d0d 100644 --- a/mammoth/translate/beam_search.py +++ b/mammoth/translate/beam_search.py @@ -234,6 +234,9 @@ def update_finished(self): _B_new = non_finished.shape[0] self.remove_finished_batches(_B_new, _B_old, non_finished, predictions, attention, step) + if self.cache is not None: + # FIXME: self.cache is a list of LayerIntermediates. Reach in and manipulate it? + self.cache = None def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, attention, step): # Remove finished batches for the next step. @@ -253,7 +256,7 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, att step - 1, _B_new * self.beam_size, inp_seq_len ) - def advance(self, log_probs, new_cache): + def advance(self, log_probs): vocab_size = log_probs.size(-1) # using integer division to get an integer _B without casting diff --git a/mammoth/translate/decode_strategy.py b/mammoth/translate/decode_strategy.py index 98cd3d84..d5db2d74 100644 --- a/mammoth/translate/decode_strategy.py +++ b/mammoth/translate/decode_strategy.py @@ -273,7 +273,7 @@ def maybe_update_target_prefix(self, select_index): return self.target_prefix = self.target_prefix.index_select(0, select_index) - def advance(self, logits, new_cache): + def advance(self, logits): """DecodeStrategy subclasses should override :func:`advance()`. Advance is used to update ``self.alive_seq``, ``self.is_finished``, diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py index 900f125b..ccc85b91 100644 --- a/mammoth/translate/translator.py +++ b/mammoth/translate/translator.py @@ -906,7 +906,7 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy): logits = logits[:, -1] log_probs = torch.log_softmax(logits, dim=-1) - decode_strategy.advance(log_probs, new_cache) + decode_strategy.advance(log_probs) any_finished = decode_strategy.is_finished.any() if any_finished: decode_strategy.update_finished() diff --git a/tools/generate_synth_data.py b/tools/generate_synth_data.py new file mode 100644 index 00000000..686216b5 --- /dev/null +++ b/tools/generate_synth_data.py @@ -0,0 +1,391 @@ +import click +import numpy as np +import os +import yaml +from typing import Optional +from pathlib import Path + +BRACKETS = [ + ('(', ')'), + ('[', ']'), + ('{', '}'), + ('<', '>'), + ('«', '»'), + ('❲', '❳'), +] + +TASKSEP_IDX = -1 +PADDING_IDX = -2 + +TOKEN_MAP = { + TASKSEP_IDX: '->', + PADDING_IDX: '', +} + + +def multi_query_associative_recall( + vocab_size: int, + num_examples: int, + seed: int, + num_kv_pairs: int = 8, + num_queries: int = 4, +): + """ + Adapted from + https://github.com/HazyResearch/zoology/blob/main/zoology/data/associative_recall.py + """ + assert num_kv_pairs * 2 < vocab_size + assert num_queries <= num_kv_pairs + + np.random.seed(seed) + + # two tokens for key and value + context_size = num_kv_pairs * 2 + + # create keys so that each key is present exactly once in each example + key_vocab_size = vocab_size // 2 + key_choices = np.arange(1, key_vocab_size) + value_choices = np.arange(key_vocab_size, vocab_size) + + keys_unshuffled = np.tile(key_choices, (num_examples, 1)) + keys = np.apply_along_axis( + np.random.choice, + 1, + keys_unshuffled, + replace=False, + size=num_kv_pairs + ) + + values_unshuffled = np.tile(value_choices, (num_examples, 1)) + values = np.apply_along_axis( + np.random.choice, + 1, + values_unshuffled, + replace=False, + size=num_kv_pairs + ) + + # create sequences + kvs = np.zeros((num_examples, context_size), dtype=np.int64) + kvs[:, 0::2] = keys + kvs[:, 1::2] = values + + # up to this point follows zoology + + # select a random shuffled index into keys, values + potential_query_indices = np.tile( + np.arange(0, num_kv_pairs), + (num_examples, 1), + ) + selected_query_indices = np.apply_along_axis( + np.random.choice, + 1, + potential_query_indices, + replace=False, + size=num_queries, + ) + query_keys = np.take_along_axis(keys, selected_query_indices, axis=1) + query_values = np.take_along_axis(values, selected_query_indices, axis=1) + + source = np.concatenate( + [kvs, np.full((num_examples, 1), TASKSEP_IDX), query_keys], + axis=1, + ) + + return source, query_values + + +def copy_source( + vocab_size: int, + num_examples: int, + seq_len: int, + seed: int, + distractor_separator: Optional[int] = None, +): + assert vocab_size > seq_len + + np.random.seed(seed) + + token_choices = np.arange(1, vocab_size) + + tokens_unshuffled = np.tile(token_choices, (num_examples, 1)) + tokens = np.apply_along_axis( + np.random.choice, + 1, + tokens_unshuffled, + replace=False, + size=seq_len + ) + if distractor_separator is not None: + source = np.concatenate( + [ + tokens[:, :distractor_separator], + np.full((num_examples, 1), TASKSEP_IDX), + tokens[:, distractor_separator:], + ], + axis=1, + ) + else: + source = tokens + + return source, tokens + + +def reverse_source(*args, **kwargs): + source, target = copy_source(*args, **kwargs) + return source, target[:, ::-1] + + +def sort_source(*args, **kwargs): + source, target = copy_source(*args, **kwargs) + target = np.sort(target, axis=1) + return source, target + + +def counting( + vocab_size: int, + num_examples: int, + max_len: int, + seed: int, +): + np.random.seed(seed) + + token_choices = np.arange(1, vocab_size) + count_choices = np.arange(1, max_len) + selected_tokens = np.random.choice(token_choices, replace=True, size=num_examples) + selected_counts = np.random.choice(count_choices, replace=True, size=num_examples) + + source = np.zeros((num_examples, max_len - 1)) + mask_indices = np.tile(count_choices, (num_examples, 1)) + mask = mask_indices <= selected_counts[:, np.newaxis] + source += selected_tokens[:, np.newaxis] * mask + source += np.full((num_examples, max_len - 1), PADDING_IDX) * ~mask + + target = np.concatenate([ + selected_tokens[:, np.newaxis], + selected_counts[:, np.newaxis], + ], axis=1) + + return source, target + + +def reverse_counting(*args, **kwargs): + target, source = counting(*args, **kwargs) + return source, target + + +def denumericalize(array): + for i in range(array.shape[0]): + tokens = [TOKEN_MAP.get(tok_i, str(int(tok_i))) for tok_i in array[i]] + line = ' '.join(tokens).strip() + yield line + + +def make_vocab( + vocab_path, + vocab_size, + specials=None +): + """ + vocab_path: Path to write out the vocabulary + vocab_size: number of normal tokens + specials: + default None means "use all specials". + To use no specials, pass an empty list. + """ + if specials is None: + vocab = ['', '', ''] + for token in TOKEN_MAP.values(): + if len(token) > 0: + vocab.append(token) + vocab.extend(range(vocab_size)) + with vocab_path.open('w') as fout: + for item in vocab: + print(f'{item}\t0', file=fout) + + +TASK_SPECS = { + 'multi_query_associative_recall_kv6_q2': { + 'func': multi_query_associative_recall, + 'func_args': { + 'num_kv_pairs': 6, + 'num_queries': 2, + }, + }, + 'multi_query_associative_recall_kv20_q4': { + 'func': multi_query_associative_recall, + 'func_args': { + 'num_kv_pairs': 20, + 'num_queries': 4, + }, + }, + 'multi_query_associative_recall_kv12_q8': { + 'func': multi_query_associative_recall, + 'func_args': { + 'num_kv_pairs': 12, + 'num_queries': 8, + }, + }, + 'copy_source': { + 'func': copy_source, + 'func_args': { + 'seq_len': 100, + }, + }, + 'distractor_separator_kv20_q4': { + # The source of this task resembles MQAR. + # Using this task prevents the separator symbol from being associated as a marker of MQAR. + 'func': copy_source, + 'func_args': { + 'seq_len': (2 * 20) + 4, + 'distractor_separator': -4, + }, + }, + 'distractor_separator_kv12_q8': { + 'func': copy_source, + 'func_args': { + 'seq_len': (2 * 12) + 8, + 'distractor_separator': -8, + }, + }, + 'reverse_source': { + 'func': reverse_source, + 'func_args': { + 'seq_len': 100, + }, + }, + 'sort_source': { + 'func': sort_source, + 'func_args': { + 'seq_len': 100, + }, + }, + 'counting': { + 'func': counting, + 'func_args': { + 'max_len': 100, + }, + }, + 'reverse_counting': { + 'func': reverse_counting, + 'func_args': { + 'max_len': 100, + }, + }, +} + + +def parse_config(config_path: Path): + with config_path.open('r') as fin: + config = yaml.safe_load(fin) + src_path_template = config['config_config']['src_path'] + tgt_path_template = config['config_config']['tgt_path'] + valid_src_path_template = config['config_config']['valid_src_path'] + valid_tgt_path_template = config['config_config']['valid_tgt_path'] + vocabs = config['src_vocab'] + + # vocab keys indicate which task param combinations you want to generate + for key, vocab_path in vocabs.items(): + if key not in TASK_SPECS: + raise Exception(f'Requested task {key} not in available tasks {TASK_SPECS.keys()}') + task_specs = TASK_SPECS[key] + paths = { + 'src_path': Path(src_path_template.format(src_lang=key, tgt_lang=key)), + 'tgt_path': Path(tgt_path_template.format(src_lang=key, tgt_lang=key)), + 'valid_src_path': Path(valid_src_path_template.format(src_lang=key, tgt_lang=key)), + 'valid_tgt_path': Path(valid_tgt_path_template.format(src_lang=key, tgt_lang=key)), + 'vocab_path': Path(vocab_path), + } + yield key, task_specs, paths + + +def ensure_parents_exist(paths): + for path in paths: + if path.parent.exists(): + if not path.parent.is_dir(): + raise Exception(f'The parent {path.parent} of path {path} exists, but is not a directory') + else: + os.makedirs(path.parent) + + +def generate_from_config( + config_path: Path, + vocab_size: int, + num_examples_train: int, + num_examples_test: int, + start_seed: int, + shared_vocab: Optional[Path] = None, +): + if shared_vocab: + ensure_parents_exist([shared_vocab]) + make_vocab(shared_vocab, vocab_size) + for i, (key, task_specs, paths) in enumerate(parse_config(config_path)): + print(f'Generating {key}...') + ensure_parents_exist(paths.values()) + args = { + 'vocab_size': vocab_size, + 'num_examples': num_examples_train + num_examples_test, + 'seed': start_seed + i, + } + args.update(task_specs['func_args']) + source, target = task_specs['func'](**args) + if task_specs.get('format', 'numpy') == 'numpy': + source_strs = denumericalize(source) + target_strs = denumericalize(target) + else: + source_strs = source + target_strs = target + with open(paths['src_path'], 'w') as fout_train, \ + open(paths['valid_src_path'], 'w') as fout_test: + for i, line in enumerate(source_strs): + if i < num_examples_train: + print(line, file=fout_train) + else: + print(line, file=fout_test) + with open(paths['tgt_path'], 'w') as fout_train, \ + open(paths['valid_tgt_path'], 'w') as fout_test: + for i, line in enumerate(target_strs): + if i < num_examples_train: + print(line, file=fout_train) + else: + print(line, file=fout_test) + # TODO: make a vocab with custom specials, if not shared_vocab + + +@click.command(context_settings={'show_default': True}) +@click.option('--config_path', type=Path, required=True) +@click.option('--vocab_size', type=int, default=300) +@click.option('--num_examples_train', type=int, default=10000) +@click.option('--num_examples_test', type=int, default=100) +@click.option('--start_seed', type=int, default=1) +@click.option( + '--shared_vocab', + type=Path, + default=None, + help='if specified, outputs a shared vocab to this path. ' + 'if not specified, task specific vocabs are created (TODO).' +) +def main( + config_path: Path, + vocab_size: int, + num_examples_train: int, + num_examples_test: int, + start_seed: int, + shared_vocab: Optional[Path] = None, +): + generate_from_config( + config_path=config_path, + vocab_size=vocab_size, + num_examples_train=num_examples_train, + num_examples_test=num_examples_test, + start_seed=start_seed, + shared_vocab=shared_vocab, + ) + + +if __name__ == '__main__': + main() + + +# other tasks: close brackets (with and without invalid inputs to detect), something else using separator? +# verification (from string?) diff --git a/tools/iterate_tasks.py b/tools/iterate_tasks.py new file mode 100644 index 00000000..f3498661 --- /dev/null +++ b/tools/iterate_tasks.py @@ -0,0 +1,58 @@ +import click +import re +import yaml +from pathlib import Path + + +@click.command() +@click.option('--config', 'config_path', type=Path, help='config file') +@click.option( + '--match', + type=str, + default=None, + help='Regex that task ids must match. Default: include all tasks', +) +@click.option( + '--src', + type=str, + default=None, + help='Template for source file paths. Use varibles src_lang and tgt_lang.', +) +@click.option( + '--output', + type=str, + default=None, + help='Template for source file paths. Use varibles src_lang, tgt_lang, and task_id.', +) +@click.option( + '--flag', + is_flag=True, + help='Prefix with "--task_id". Implied by --src and --output.' +) +def main(config_path, match, src, output, flag): + if src is not None or output is not None: + flag = True + if match: + match = re.compile(match) + with config_path.open('r') as fin: + config = yaml.safe_load(fin) + for key, task in config['tasks'].items(): + if match and not match.match(key): + continue + src_lang, tgt_lang = task['src_tgt'].split('-') + + result = [] + if flag: + result.append('--task_id') + result.append(key) + if src: + task_src = src.format(src_lang=src_lang, tgt_lang=tgt_lang, task_id=key) + result.extend(['--src', task_src]) + if output: + task_out = output.format(src_lang=src_lang, tgt_lang=tgt_lang, task_id=key) + result.extend(['--output', task_out]) + print(' '.join(result)) + + +if __name__ == '__main__': + main()