From e8aa749a754c97dddb4355ef483f3a97658195ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 26 Aug 2024 16:02:02 +0300
Subject: [PATCH] Reimplement test for encoder and model output shapes

---
 mammoth/tests/test_models.py             | 297 ++++++++++++-----------
 mammoth/tests/test_task_queue_manager.py |   6 +-
 2 files changed, 163 insertions(+), 140 deletions(-)

diff --git a/mammoth/tests/test_models.py b/mammoth/tests/test_models.py
index 1cd787b7..7c9a8350 100644
--- a/mammoth/tests/test_models.py
+++ b/mammoth/tests/test_models.py
@@ -1,39 +1,105 @@
-import copy
-import unittest
+import pytest
 import torch
 
 import mammoth
 import mammoth.opts
-from mammoth.model_builder import build_xcoder
+from mammoth.model_builder import build_model, build_xcoder
 from mammoth.inputters.vocab import Vocab, DEFAULT_SPECIALS
 from mammoth.utils.parse import ArgumentParser
+from mammoth.distributed.components import (
+    Side,
+    DistributedEncoder,
+    DistributedDecoder,
+    DistributedEmbedding,
+)
+from mammoth.distributed.tasks import DatasetMetadata, TaskSpecs, TaskQueueManager, RoundRobinTaskDistributionStrategy
+from mammoth.distributed.contexts import WorldContext, DeviceContextEnum
 
 parser = ArgumentParser(description='train.py')
 
 mammoth.opts.model_opts(parser)
 mammoth.opts._add_train_general_opts(parser)
 
-# -data option is required, but not used in this test, so dummy.
-opts = parser.parse_known_args(
-    ['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500', '-seed', '1'],
-    strict=False
-)[0]
-
-
-class TestModel(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(TestModel, self).__init__(*args, **kwargs)
-        self.opts = opts
-
-    def get_field(self):
-        return Vocab(None, items=[], tag='dummy', specials=list(DEFAULT_SPECIALS))
-
-    def get_batch(self, source_l=3, bsize=1):
-        # len x batch x nfeat
-        test_src = torch.ones(source_l, bsize, 1).long()
-        test_tgt = torch.ones(source_l, bsize, 1).long()
-        test_length = torch.ones(bsize).fill_(source_l).long()
-        return test_src, test_tgt, test_length
+DEFAULT_ARGS = '-tasks dummy -node_rank 0 -model_dim 500 -seed 1'
+
+VOCABS = {
+    ('src', 'a'): Vocab(None, items=['a'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
+    ('src', 'b'): Vocab(None, items=['b', 'bb'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
+    ('tgt', 'a'): Vocab(None, items=['_a'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
+    ('tgt', 'b'): Vocab(None, items=['_b', '_bb', '_bbb'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
+}
+
+TASK_SPECS = {
+    'dummy_a-b': TaskSpecs(
+        node_rank=0,
+        local_rank=0,
+        src_lang='a',
+        tgt_lang='b',
+        encoder_id=['foo'],
+        decoder_id=['bar'],
+        corpus_id='a-b',
+        weight=1,
+        introduce_at_training_step=0,
+        corpus_opts=dict(),
+        src_vocab=VOCABS[('src', 'a')],
+        tgt_vocab=VOCABS[('tgt', 'b')],
+        encoder_adapter_ids=None,
+        decoder_adapter_ids=None,
+    ),
+}
+
+
+class MockGroup:
+    def __init__(self):
+        self.group_idx = 0
+
+    def __call__(self, sorted_global_ranks):
+        result = f"Group {self.group_idx} with GPU ranks {sorted_global_ranks}"
+        self.group_idx += 1
+        return result
+
+
+class TestModel():
+    def __init__(self, args, tasks):
+        self.opts = self.parse_args(args)
+        world_context = WorldContext(DeviceContextEnum.MULTI_GPU, n_nodes=1, gpus_per_node=2)
+        self.tasks = [TASK_SPECS[task] for task in tasks]
+        self.tqm = TaskQueueManager(
+            tasks=self.tasks,
+            accum_count=1,
+            world_context=world_context,
+            task_distribution_strategy_cls=RoundRobinTaskDistributionStrategy,
+            uses_adapters=False,
+        ).global_to_local(
+            node_rank=0,
+            local_rank=0,
+            opts=self.opts,
+        )
+        self.vocabs_dict = {
+            (side, lang): vocab for (side, lang, _, vocab) in self.tqm.get_my_vocabs('src', VOCABS)
+        }
+        self.vocabs_dict.update({
+            (side, lang): vocab for (side, lang, _, vocab) in self.tqm.get_my_vocabs('tgt', VOCABS)
+        })
+
+        self.tqm.create_all_distributed_components(
+            use_attention_bridge=False, new_group_func=MockGroup()
+        )
+
+    def parse_args(self, args):
+        opts = parser.parse_known_args(
+            ' '.join([DEFAULT_ARGS, args]).split(),
+            strict=False
+        )[0]
+        return opts
+
+    def get_batch(self, source_l=3, bsize=1, task=None):
+        # x-transformers takes shape (batch, time)
+        test_src = torch.ones(bsize, source_l).long()
+        test_tgt = torch.ones(bsize, source_l).long()
+        test_mask = torch.ones(bsize, source_l).bool()
+        metadata = task.get_serializable_metadata()
+        return test_src, test_tgt, test_mask, metadata
 
     # Broken in x-transformers
     # def embeddings_forward(self, opts, source_l=3, bsize=1):
@@ -58,140 +124,97 @@ def get_batch(self, source_l=3, bsize=1):
     #
     #     self.assertEqual(res.size(), compare_to.size())
 
-    def encoder_forward(self, opts, source_l=3, bsize=1):
+    def encoder_forward(self, source_l=3, bsize=1):
         '''
         Tests if the encoder works as expected
         args:
-            opts: set of options
             source_l: Length of generated input sentence
             bsize: Batchsize of generated input
         '''
-        word_field = self.get_field()
-        embeddings = build_embeddings(opts, word_field)
-        enc = build_encoder(opts, embeddings)
+        token_embs = None
+        device = 'cpu'
 
-        test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
+        task = self.tasks[0]
+        test_src, test_tgt, test_mask, metadata = self.get_batch(source_l=source_l, bsize=bsize, task=task)
 
-        hidden_t, outputs, test_length = enc(test_src, test_length)
+        enc = build_xcoder(
+            Side.encoder,
+            self.opts,
+            self.vocabs_dict,
+            device,
+            task_queue_manager=self.tqm,
+            single_task=None,
+            token_embs=token_embs,
+        )
+        active_encoder = enc.activate(task_id=task.corpus_id, adapter_ids=task.encoder_adapter_ids)
 
-        # Initialize vectors to compare size with
-        test_hid = torch.zeros(self.opts.enc_layers, bsize, opts.model_dim)
-        test_out = torch.zeros(source_l, bsize, opts.model_dim)
+        encoder_output = active_encoder(test_src, mask=test_mask, return_embeddings=True)
 
-        # Ensure correct sizes and types
-        self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size())
-        self.assertEqual(test_out.size(), outputs.size())
-        self.assertEqual(type(outputs), torch.Tensor)
+        # Make sure that output has the correct size and type
+        # x-transformers returns (batch, time, dim/vocab_index)
+        outputsize = torch.zeros(bsize, source_l, self.opts.model_dim)
+        assert encoder_output.size() == outputsize.size()
+        assert isinstance(encoder_output, torch.Tensor)
 
-    def nmtmodel_forward(self, opts, source_l=3, bsize=1):
+    def nmtmodel_forward(self, source_l=3, bsize=1):
         """
         Creates a nmtmodel with a custom opts function.
         Forwards a testbatch and checks output size.
 
         Args:
-            opts: Namespace with options
             source_l: length of input sequence
             bsize: batchsize
         """
-        word_field = self.get_field()
-
-        embeddings = build_embeddings(opts, word_field)
-        enc = build_encoder(opts, embeddings)
-
-        embeddings = build_embeddings(opts, word_field, for_encoder=False)
-        dec = build_decoder(opts, embeddings)
+        model = build_model(
+            self.opts,
+            self.opts,
+            self.vocabs_dict,
+            task_queue_manager=self.tqm,
+            single_task=None,
+        )
+
+        task = self.tasks[0]
+        tgt_vocab = self.vocabs_dict[('tgt', task.tgt_lang)]
+        test_src, test_tgt, test_mask, metadata = self.get_batch(source_l=source_l, bsize=bsize, task=task)
+        # currently caller must adjust for the autoregressive step
+        # shape: (batch, time)
+        decoder_input = test_tgt[:, :-1]
+        logits, decoder_output, attentions = model(test_src, decoder_input, test_mask, metadata=metadata)
+        # Make sure that output has the correct size and type
+        # x-transformers returns (batch, time, dim/vocab_index)
+        logitsize = torch.zeros(bsize, source_l - 1, len(tgt_vocab))
+        outputsize = torch.zeros(bsize, source_l - 1, self.opts.model_dim)
+        assert logits.size() == logitsize.size()
+        assert isinstance(logits, torch.Tensor)
+        assert decoder_output.size() == outputsize.size()
+        assert isinstance(decoder_output, torch.Tensor)
 
-        model = mammoth.models.model.NMTModel(enc, dec)
-        test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
-        outputs, attn = model(test_src, test_tgt, test_length)
-        outputsize = torch.zeros(source_l - 1, bsize, opts.model_dim)
-        # Make sure that output has the correct size and type
-        self.assertEqual(outputs.size(), outputsize.size())
-        self.assertEqual(type(outputs), torch.Tensor)
-
-
-def _add_test(param_setting, methodname):
-    """
-    Adds a Test to TestModel according to settings
-
-    Args:
-        param_setting: list of tuples of (param, setting)
-        methodname: name of the method that gets called
-    """
-
-    def test_method(self):
-        opts = copy.deepcopy(self.opts)
-        if param_setting:
-            for param, setting in param_setting:
-                setattr(opts, param, setting)
-        ArgumentParser.update_model_opts(opts)
-        getattr(self, methodname)(opts)
-
-    if param_setting:
-        name = 'test_' + methodname + "_" + "_".join(str(param_setting).split())
-    else:
-        name = 'test_' + methodname + '_standard'
-    setattr(TestModel, name, test_method)
-    test_method.__name__ = name
-
-
-'''
-TEST PARAMETERS
-'''
-opts.brnn = False
-
-# FIXME: Most tests disabled: MAMMOTH only supports Transformer
-test_embeddings = [
-    # [],
-    [('decoder_type', 'transformer')]
-]
-
-for p in test_embeddings:
-    _add_test(p, 'embeddings_forward')
-
-# FIXME: All tests disabled: MAMMOTH only supports Transformer, and the test for Transformer is broken
-tests_encoder = [
-    # [],
-    # [('encoder_type', 'mean')],
-    # [('encoder_type', 'transformer'), ('word_vec_size', 16), ('model_dim', 16)],
-    # [],
-]
-
-for p in tests_encoder:
-    _add_test(p, 'encoder_forward')
-
-# FIXME: Most tests disabled: MAMMOTH only supports Transformer
-tests_nmtmodel = [
-    # [('rnn_type', 'GRU')],
-    # [('layers', 10)],
-    # [('input_feed', 0)],
-    [
-        ('decoder_type', 'transformer'),
-        ('encoder_type', 'transformer'),
-        ('src_word_vec_size', 16),
-        ('tgt_word_vec_size', 16),
-        ('model_dim', 16),
-    ],
+@pytest.mark.parametrize(
+    ('args', 'tasks', 'source_l', 'bsize'),
     [
-        ('decoder_type', 'transformer'),
-        ('encoder_type', 'transformer'),
-        ('src_word_vec_size', 16),
-        ('tgt_word_vec_size', 16),
-        ('model_dim', 16),
-        ('position_encoding', True),
+        (
+            '--enc_layers 1 --dec_layers 1',
+            ['dummy_a-b'],
+            3,
+            1,
+        ),
+        (
+            '--enc_layers 1 --dec_layers 1',
+            ['dummy_a-b'],
+            5,
+            7,
+        ),
+        (
+            '--enc_layers 3 --dec_layers 2',
+            ['dummy_a-b'],
+            4,
+            1,
+        ),
     ],
-    # [('context_gate', 'both')],
-    # [('context_gate', 'target')],
-    # [('context_gate', 'source')],
-    # [('encoder_type', "brnn"), ('brnn_merge', 'sum')],
-    # [('encoder_type', "brnn")],
-    # [('decoder_type', 'cnn'), ('encoder_type', 'cnn')],
-    # [],
-]
-
-
-# ## FIXME: Broken in MAMMOTH
-# for p in tests_nmtmodel:
-#     _add_test(p, 'nmtmodel_forward')
+)
+def test_nmtmodel(args, tasks, source_l, bsize):
+    tm = TestModel(args, tasks)
+    tm.nmtmodel_forward(source_l=source_l, bsize=bsize)
+    tm.encoder_forward(source_l=source_l, bsize=bsize)
diff --git a/mammoth/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py
index 29e6a4ed..30df3202 100644
--- a/mammoth/tests/test_task_queue_manager.py
+++ b/mammoth/tests/test_task_queue_manager.py
@@ -193,7 +193,7 @@ def __call__(self, sorted_global_ranks):
         DistributedEmbedding(
             global_ranks={0, 1},
             task_ids={"train_0_a-b", "train_2_a-d"},
-            group="Group 3 with GPU ranks [0, 1]",
+            group="Group 2 with GPU ranks [0, 1]",
             side=Side.encoder,
             lang="a",
         ),
@@ -206,7 +206,7 @@ def __call__(self, sorted_global_ranks):
         DistributedEmbedding(
             global_ranks={0, 2},
             task_ids={'train_3_e-b', 'train_0_a-b'},
-            group="Group 4 with GPU ranks [0, 2]",
+            group="Group 3 with GPU ranks [0, 2]",
             side=Side.decoder,
             lang="b",
         ),
@@ -268,7 +268,7 @@ def __call__(self, sorted_global_ranks):
         DistributedEmbedding(
             global_ranks={0, 1},
             task_ids={"train_0_a-b", "train_2_a-d"},
-            group="Group 3 with GPU ranks [0, 1]",
+            group="Group 2 with GPU ranks [0, 1]",
             side=Side.encoder,
             lang="a",
         ),
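
A quick way to exercise the reimplemented tests locally (a sketch, assuming a
MAMMOTH development checkout with pytest and the x-transformers dependency
installed; test_nmtmodel is the function added above):

    # run only the reimplemented shape tests, verbosely
    pytest mammoth/tests/test_models.py -k test_nmtmodel -v

Because the test requests device 'cpu' and stubs out process-group creation
with MockGroup, the parametrized cases should run without physical GPUs even
though the WorldContext declares two GPUs per node. Additional depth/shape
combinations can be covered by appending another tuple to the parametrize
list, e.g. (a hypothetical case):

    (
        '--enc_layers 2 --dec_layers 2',
        ['dummy_a-b'],
        8,
        2,
    ),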