Reimplement test for encoder and model output shapes
Waino committed Aug 26, 2024
1 parent 1b01dfd commit 5ee737e
Showing 2 changed files with 163 additions and 138 deletions.
295 changes: 160 additions & 135 deletions mammoth/tests/test_models.py
@@ -1,39 +1,105 @@
import copy
import unittest
import pytest

import torch

import mammoth
import mammoth.opts
from mammoth.model_builder import build_xcoder
from mammoth.model_builder import build_model, build_xcoder
from mammoth.inputters.vocab import Vocab, DEFAULT_SPECIALS
from mammoth.utils.parse import ArgumentParser
from mammoth.distributed.components import (
Side,
DistributedEncoder,
DistributedDecoder,
DistributedEmbedding,
)
from mammoth.distributed.tasks import DatasetMetadata, TaskSpecs, TaskQueueManager, RoundRobinTaskDistributionStrategy
from mammoth.distributed.contexts import WorldContext, DeviceContextEnum

parser = ArgumentParser(description='train.py')
mammoth.opts.model_opts(parser)
mammoth.opts._add_train_general_opts(parser)

# -data option is required, but not used in this test, so dummy.
opts = parser.parse_known_args(
['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500', '-seed', '1'],
strict=False
)[0]


class TestModel(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestModel, self).__init__(*args, **kwargs)
self.opts = opts

def get_field(self):
return Vocab(None, items=[], tag='dummy', specials=list(DEFAULT_SPECIALS))

def get_batch(self, source_l=3, bsize=1):
# len x batch x nfeat
test_src = torch.ones(source_l, bsize, 1).long()
test_tgt = torch.ones(source_l, bsize, 1).long()
test_length = torch.ones(bsize).fill_(source_l).long()
return test_src, test_tgt, test_length
DEFAULT_ARGS = '-tasks dummy -node_rank 0 -model_dim 500 -seed 1'

VOCABS = {
('src', 'a'): Vocab(None, items=['a'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
('src', 'b'): Vocab(None, items=['b', 'bb'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
('tgt', 'a'): Vocab(None, items=['_a'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
('tgt', 'b'): Vocab(None, items=['_b', '_bb', '_bbb'], tag='dummy', specials=list(DEFAULT_SPECIALS)),
}

TASK_SPECS = {
'dummy_a-b': TaskSpecs(
node_rank=0,
local_rank=0,
src_lang='a',
tgt_lang='b',
encoder_id=['foo'],
decoder_id=['bar'],
corpus_id='a-b',
weight=1,
introduce_at_training_step=0,
corpus_opts=dict(),
src_vocab=VOCABS[('src', 'a')],
tgt_vocab=VOCABS[('tgt', 'b')],
encoder_adapter_ids=None,
decoder_adapter_ids=None,
),
}


class MockGroup:
def __init__(self):
self.group_idx = 0

def __call__(self, sorted_global_ranks):
result = f"Group {self.group_idx} with GPU ranks {sorted_global_ranks}"
self.group_idx += 1
return result


class TestModel():
def __init__(self, args, tasks):
self.opts = self.parse_args(args)
world_context = WorldContext(DeviceContextEnum.MULTI_GPU, n_nodes=1, gpus_per_node=2)
self.tasks = [TASK_SPECS[task] for task in tasks]
self.tqm = TaskQueueManager(
tasks=self.tasks,
accum_count=1,
world_context=world_context,
task_distribution_strategy_cls=RoundRobinTaskDistributionStrategy,
uses_adapters=False,
).global_to_local(
node_rank=0,
local_rank=0,
opts=self.opts,
)
self.vocabs_dict = {
(side, lang): vocab for (side, lang, _, vocab) in self.tqm.get_my_vocabs('src', VOCABS)
}
self.vocabs_dict.update({
(side, lang): vocab for (side, lang, _, vocab) in self.tqm.get_my_vocabs('tgt', VOCABS)
})

self.tqm.create_all_distributed_components(
use_attention_bridge=False, new_group_func=MockGroup()
)

def parse_args(self, args):
opts = parser.parse_known_args(
' '.join([DEFAULT_ARGS, args]).split(),
strict=False
)[0]
return opts

def get_batch(self, source_l=3, bsize=1, task=None):
# x-transformers takes shape (batch, time)
test_src = torch.ones(bsize, source_l).long()
test_tgt = torch.ones(bsize, source_l).long()
test_mask = torch.ones(bsize, source_l).bool()
metadata = task.get_serializable_metadata()
return test_src, test_tgt, test_mask, metadata

# Broken in x-transformers
# def embeddings_forward(self, opts, source_l=3, bsize=1):
@@ -58,7 +124,7 @@ def get_batch(self, source_l=3, bsize=1):
#
# self.assertEqual(res.size(), compare_to.size())

def encoder_forward(self, opts, source_l=3, bsize=1):
def encoder_forward(self, source_l=3, bsize=1):
'''
Tests if the encoder works as expected
@@ -67,24 +133,32 @@ def encoder_forward(self, opts, source_l=3, bsize=1):
source_l: Length of generated input sentence
bsize: Batchsize of generated input
'''
word_field = self.get_field()
embeddings = build_embeddings(opts, word_field)
enc = build_encoder(opts, embeddings)
token_embs = None
device = 'cpu'

test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
task = self.tasks[0]
test_src, test_tgt, test_mask, metadata = self.get_batch(source_l=source_l, bsize=bsize, task=task)

hidden_t, outputs, test_length = enc(test_src, test_length)
enc = build_xcoder(
Side.encoder,
self.opts,
self.vocabs_dict,
device,
task_queue_manager=self.tqm,
single_task=None,
token_embs=token_embs,
)
active_encoder = enc.activate(task_id=task.corpus_id, adapter_ids=task.encoder_adapter_ids)

# Initialize vectors to compare size with
test_hid = torch.zeros(self.opts.enc_layers, bsize, opts.model_dim)
test_out = torch.zeros(source_l, bsize, opts.model_dim)
encoder_output = active_encoder(test_src, mask=test_mask, return_embeddings=True)

# Ensure correct sizes and types
self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size())
self.assertEqual(test_out.size(), outputs.size())
self.assertEqual(type(outputs), torch.Tensor)
# Make sure that output has the correct size and type
# x-transformers returns (batch, time, dim/vocab_index)
outputsize = torch.zeros(bsize, source_l, self.opts.model_dim)
assert encoder_output.size() == outputsize.size()
assert isinstance(encoder_output, torch.Tensor)

def nmtmodel_forward(self, opts, source_l=3, bsize=1):
def nmtmodel_forward(self, source_l=3, bsize=1):
"""
Creates a nmtmodel with a custom opts function.
Forwards a testbatch and checks output size.
@@ -94,104 +168,55 @@ def nmtmodel_forward(self, opts, source_l=3, bsize=1):
source_l: length of input sequence
bsize: batchsize
"""
word_field = self.get_field()

embeddings = build_embeddings(opts, word_field)
enc = build_encoder(opts, embeddings)

embeddings = build_embeddings(opts, word_field, for_encoder=False)
dec = build_decoder(opts, embeddings)
model = build_model(
self.opts,
self.opts,
self.vocabs_dict,
task_queue_manager=self.tqm,
single_task=None,
)

task = self.tasks[0]
tgt_vocab = self.vocabs_dict[('tgt', task.tgt_lang)]
test_src, test_tgt, test_mask, metadata = self.get_batch(source_l=source_l, bsize=bsize, task=task)
# currently caller must adjust for the autoregressive step
# shape: (batch, time)
decoder_input = test_tgt[:, :-1]
logits, decoder_output, attentions = model(test_src, decoder_input, test_mask, metadata=metadata)
# Make sure that output has the correct size and type
# x-transformers returns (batch, time, dim/vocab_index)
logitsize = torch.zeros(bsize, source_l - 1, len(tgt_vocab))
outputsize = torch.zeros(bsize, source_l - 1, self.opts.model_dim)
assert logits.size() == logitsize.size()
assert isinstance(logits, torch.Tensor)
assert decoder_output.size() == outputsize.size()
assert isinstance(decoder_output, torch.Tensor)

model = mammoth.models.model.NMTModel(enc, dec)

test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
outputs, attn = model(test_src, test_tgt, test_length)
outputsize = torch.zeros(source_l - 1, bsize, opts.model_dim)
# Make sure that output has the correct size and type
self.assertEqual(outputs.size(), outputsize.size())
self.assertEqual(type(outputs), torch.Tensor)


def _add_test(param_setting, methodname):
"""
Adds a Test to TestModel according to settings
Args:
param_setting: list of tuples of (param, setting)
methodname: name of the method that gets called
"""

def test_method(self):
opts = copy.deepcopy(self.opts)
if param_setting:
for param, setting in param_setting:
setattr(opts, param, setting)
ArgumentParser.update_model_opts(opts)
getattr(self, methodname)(opts)

if param_setting:
name = 'test_' + methodname + "_" + "_".join(str(param_setting).split())
else:
name = 'test_' + methodname + '_standard'
setattr(TestModel, name, test_method)
test_method.__name__ = name


'''
TEST PARAMETERS
'''
opts.brnn = False

# FIXME: Most tests disabled: MAMMOTH only supports Transformer
test_embeddings = [
# [],
[('decoder_type', 'transformer')]
]

for p in test_embeddings:
_add_test(p, 'embeddings_forward')

# FIXME: All tests disabled: MAMMOTH only supports Transformer, and the test for Transformer is broken
tests_encoder = [
# [],
# [('encoder_type', 'mean')],
# [('encoder_type', 'transformer'), ('word_vec_size', 16), ('model_dim', 16)],
# [],
]

for p in tests_encoder:
_add_test(p, 'encoder_forward')

# FIXME: Most tests disabled: MAMMOTH only supports Transformer
tests_nmtmodel = [
# [('rnn_type', 'GRU')],
# [('layers', 10)],
# [('input_feed', 0)],
[
('decoder_type', 'transformer'),
('encoder_type', 'transformer'),
('src_word_vec_size', 16),
('tgt_word_vec_size', 16),
('model_dim', 16),
],
[
('decoder_type', 'transformer'),
('encoder_type', 'transformer'),
('src_word_vec_size', 16),
('tgt_word_vec_size', 16),
('model_dim', 16),
('position_encoding', True),
],
# [('context_gate', 'both')],
# [('context_gate', 'target')],
# [('context_gate', 'source')],
# [('encoder_type', "brnn"), ('brnn_merge', 'sum')],
# [('encoder_type', "brnn")],
# [('decoder_type', 'cnn'), ('encoder_type', 'cnn')],
# [],
]


# ## FIXME: Broken in MAMMOTH
# for p in tests_nmtmodel:
#     _add_test(p, 'nmtmodel_forward')
@pytest.mark.parametrize(
('args', 'tasks', 'source_l', 'bsize'),
[
(
'--enc_layers 1 --dec_layers 1',
['dummy_a-b'],
3,
1,
),
(
'--enc_layers 1 --dec_layers 1',
['dummy_a-b'],
5,
7,
),
(
'--enc_layers 3 --dec_layers 2',
['dummy_a-b'],
4,
1,
),
],
)
def test_nmtmodel(args, tasks, source_l, bsize):
tm = TestModel(args, tasks)
tm.nmtmodel_forward(source_l=source_l, bsize=bsize)
tm.encoder_forward(source_l=source_l, bsize=bsize)
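
A minimal sketch of what one parametrized case exercises, written against the helpers defined above (DEFAULT_ARGS, TASK_SPECS, and the 'dummy_a-b' task). The expected shapes follow from the assertions in encoder_forward and nmtmodel_forward; the snippet is illustrative, not part of the commit.

# Hedged sketch: drive the helper directly, mirroring the first parametrize case.
tm = TestModel('--enc_layers 1 --dec_layers 1', ['dummy_a-b'])
tm.encoder_forward(source_l=3, bsize=1)    # encoder output: (bsize, source_l, model_dim) == (1, 3, 500)
tm.nmtmodel_forward(source_l=3, bsize=1)   # logits: (bsize, source_l - 1, len(tgt_vocab)); decoder output: (bsize, source_l - 1, model_dim)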
6 changes: 3 additions & 3 deletions mammoth/tests/test_task_queue_manager.py
@@ -193,7 +193,7 @@ def __call__(self, sorted_global_ranks):
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
group="Group 3 with GPU ranks [0, 1]",
group="Group 2 with GPU ranks [0, 1]",
side=Side.encoder,
lang="a",
),
@@ -206,7 +206,7 @@ def __call__(self, sorted_global_ranks):
DistributedEmbedding(
global_ranks={0, 2},
task_ids={'train_3_e-b', 'train_0_a-b'},
group="Group 4 with GPU ranks [0, 2]",
group="Group 3 with GPU ranks [0, 2]",
side=Side.decoder,
lang="b",
),
@@ -268,7 +268,7 @@ def __call__(self, sorted_global_ranks):
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
group="Group 3 with GPU ranks [0, 1]",
group="Group 2 with GPU ranks [0, 1]",
side=Side.encoder,
lang="a",
),
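
The expected group labels in these fixtures shift down by one. A minimal sketch of why, assuming the MockGroup helper shown earlier; the cause (one fewer component group being created before the embedding groups) is an inference from this diff rather than something it states.

mock = MockGroup()
mock([0, 1])  # returns "Group 0 with GPU ranks [0, 1]"
mock([0, 2])  # returns "Group 1 with GPU ranks [0, 2]"
# MockGroup numbers groups in creation order, so creating one group fewer
# ahead of the embeddings shifts every later label down by one,
# e.g. "Group 3 with GPU ranks [0, 1]" becomes "Group 2 with GPU ranks [0, 1]".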
