
Commit

Merge pull request #48 from Helsinki-NLP/embeddingless
Implemented an option to replace embedding vectors with non-trainable one-hot encoding vectors
josephattieh authored Feb 6, 2024
2 parents dfbe9ef + e87f39e commit 632feed
Showing 18 changed files with 106 additions and 68 deletions.
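
For orientation, here is a minimal standalone sketch of the idea this commit implements (illustrative only: the helper name one_hot_embedding and the sizes are made up; the actual implementation is in mammoth/modules/embeddings.py below). The trainable lookup table is replaced by a fixed one-hot matrix, right-padded with zero columns up to the model dimension, with the padding row zeroed and gradients disabled.

# Minimal sketch of a non-trainable one-hot embedding, assuming model_dim >= vocab_size
# (as in the byte-level setting this option targets). Illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F


def one_hot_embedding(vocab_size: int, model_dim: int, padding_idx: int) -> nn.Embedding:
    # Identity-like lookup table: row i is the one-hot vector for token i,
    # right-padded with zeros up to model_dim.
    weight = F.one_hot(torch.arange(vocab_size), num_classes=vocab_size).float()
    weight = torch.cat([weight, torch.zeros(vocab_size, model_dim - vocab_size)], dim=1)
    weight[padding_idx] = 0.0  # padding token maps to the all-zero vector
    emb = nn.Embedding(vocab_size, model_dim, padding_idx=padding_idx)
    emb.weight = nn.Parameter(weight, requires_grad=False)  # frozen: never updated by training
    return emb


emb = one_hot_embedding(vocab_size=260, model_dim=512, padding_idx=0)
tokens = torch.tensor([[1, 7, 0]])
print(emb(tokens).shape)         # torch.Size([1, 3, 512])
print(emb.weight.requires_grad)  # False
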
12 changes: 6 additions & 6 deletions docs/source/conf.py
@@ -56,10 +56,10 @@
}

mathjax3_config = {
"tex": {
"inlineMath": [['\\(', '\\)']],
"displayMath": [["\\[", "\\]"]],
}
"tex": {
"inlineMath": [['\\(', '\\)']],
"displayMath": [["\\[", "\\]"]],
}
}

bibtex_bibfiles = ['refs.bib']
@@ -144,8 +144,8 @@
html_context = {
'css_files': [
'_static/theme_overrides.css', # override wide tables in RTD theme
],
}
],
}

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
2 changes: 1 addition & 1 deletion mammoth/distributed/tasks.py
@@ -197,7 +197,7 @@ def __init__(
"""
self.tasks = tasks
# TODO: no support for variable accumulation across training
self.accum_count = accum_count[0] if type(accum_count) is list else accum_count
self.accum_count = accum_count[0] if isinstance(accum_count, list) else accum_count
self.task_distribution_strategy = task_distribution_strategy
self.world_context = world_context
self.device_context = device_context
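Several files in this commit replace type(x) is list checks with isinstance(x, list). A quick illustration of the difference (AccumSchedule is a hypothetical subclass used only for this example): isinstance also accepts subclasses of list, which the exact type check rejects.

# Illustration only: exact type checks reject subclasses, isinstance does not.
class AccumSchedule(list):
    """A hypothetical list subclass carrying extra metadata."""

accum_count = AccumSchedule([4, 8])

print(type(accum_count) is list)      # False - exact type check fails for the subclass
print(isinstance(accum_count, list))  # True  - the subclass is still treated as a list
print(accum_count[0] if isinstance(accum_count, list) else accum_count)  # 4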
5 changes: 3 additions & 2 deletions mammoth/inputters/dataloader.py
@@ -53,6 +53,7 @@ def numel_fn(example_dict):

class InferenceBatcher():
"""Iterator for inference"""

def __init__(self, dataset, batch_size):
self.examples_stream = dataset
self.collate_fn = dataset.collate_fn
@@ -346,8 +347,8 @@ def __iter__(self):
if communication_batch_id == 0:
# De-numericalize a few sentences for debugging
logger.warning(
f'src shape: {batch.src[0].shape} tgt shape: {batch.tgt.shape} '
f'batch size: {batch.batch_size}'
f'src shape: {batch.src[0].shape} tgt shape: {batch.tgt.shape} '
f'batch size: {batch.batch_size}'
)
src_vocab = self.vocabs_dict[('src', metadata.src_lang)]
tgt_vocab = self.vocabs_dict[('tgt', metadata.tgt_lang)]
13 changes: 7 additions & 6 deletions mammoth/inputters/dataset.py
@@ -76,6 +76,7 @@ def _make_example_dict(packed):

class ParallelCorpus(IterableDataset):
"""Torch-style dataset"""

def __init__(
self,
src_file,
@@ -245,12 +246,12 @@ def build_vocab_counts(opts, corpus_id, transforms, n_sample=3):

corpora = {
corpus_id: read_examples_from_files(
opts.tasks[corpus_id]["path_src"],
opts.tasks[corpus_id]["path_tgt"],
# FIXME this is likely not working
transforms_fn=TransformPipe(transforms).apply if transforms else lambda x: x,
)
}
opts.tasks[corpus_id]["path_src"],
opts.tasks[corpus_id]["path_tgt"],
# FIXME this is likely not working
transforms_fn=TransformPipe(transforms).apply if transforms else lambda x: x,
)
}
counter_src = collections.Counter()
counter_tgt = collections.Counter()

38 changes: 20 additions & 18 deletions mammoth/model_builder.py
@@ -40,15 +40,17 @@ def build_embeddings(opts, vocab, for_encoder=True):
opts.word_padding_idx = word_padding_idx

freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec

emb = Embeddings(
word_vec_size=opts.model_dim,
position_encoding=opts.position_encoding,
dropout=opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
dropout=opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
word_padding_idx=word_padding_idx,
word_vocab_size=len(vocab),
freeze_word_vecs=freeze_word_vecs,
enable_embeddingless=opts.enable_embeddingless
)
if opts.enable_embeddingless:
logger.info("Creating an embeddingless model.")
return emb


@@ -152,7 +154,7 @@ def load_test_multitask_model(opts, task=None, model_path=None):
task=task,
model_opts=model_opts,
vocabs_dict=vocabs_dict
)
)
model_params = {name for name, p in model.named_parameters()}
model_params.update(name for name, p in model.named_buffers())
for key in set(combined_state_dict.keys()):
@@ -228,7 +230,6 @@ def create_bilingual_model(
src_lang = task.src_lang
tgt_lang = task.tgt_lang
generators_md = nn.ModuleDict()

src_emb = build_src_emb(model_opts, vocabs_dict['src'])
tgt_emb = build_tgt_emb(model_opts, vocabs_dict['tgt'])
pluggable_src_emb = PluggableEmbeddings({src_lang: src_emb})
@@ -248,15 +249,13 @@
decoder=decoder,
attention_bridge=attention_bridge
)

if uses_adapters(model_opts):
logger.info('Creating adapters...')
create_bilingual_adapters(nmt_model, model_opts, task)
else:
logger.info('Does not use adapters...')
print('built model:')
print(nmt_model)

nmt_model.generator = generators_md
return nmt_model

@@ -300,7 +299,6 @@ def build_task_specific_model(
for side, lang, _, vocab in task_queue_manager.get_vocabs(side='src', vocabs_dict=vocabs_dict):
src_emb = build_src_emb(model_opts, vocab)
src_embs[lang] = src_emb

pluggable_src_emb = PluggableEmbeddings(src_embs)
encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager)

@@ -366,12 +364,15 @@ def build_only_enc(model_opts, src_emb, task_queue_manager):
"""Truly only builds encoder: no embeddings"""
encoder = build_encoder(model_opts, src_emb, task_queue_manager)
if model_opts.param_init != 0.0:
for p in encoder.parameters():
p.data.uniform_(-model_opts.param_init, model_opts.param_init)
for name, p in encoder.named_parameters():
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
p.data.uniform_(-model_opts.param_init, model_opts.param_init)

if model_opts.param_init_glorot:
for p in encoder.parameters():
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
for name, p in encoder.named_parameters():
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
encoder.half()

@@ -380,14 +381,15 @@

def build_only_dec(model_opts, tgt_emb, task_queue_manager):
decoder = build_decoder(model_opts, tgt_emb, task_queue_manager)

if model_opts.param_init != 0.0:
for p in decoder.parameters():
p.data.uniform_(-model_opts.param_init, model_opts.param_init)
for name, p in decoder.named_parameters():
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
p.data.uniform_(-model_opts.param_init, model_opts.param_init)
if model_opts.param_init_glorot:
for p in decoder.parameters():
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
for name, p in decoder.named_parameters():
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))

if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
decoder.half()
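
When enable_embeddingless is set, the initialization loops above skip every parameter whose name contains "embedding" but not "pe", so the frozen one-hot lookup tables keep their construction-time values while all other encoder/decoder parameters are still initialized. A toy illustration of the name filter (the module below is made up and is not MAMMOTH's encoder or decoder):

# Toy illustration of the name-based filter used above; names here are invented.
import torch.nn as nn

toy = nn.ModuleDict({
    'embeddings': nn.Embedding(10, 16),    # -> 'embeddings.weight' (skipped: frozen one-hot table)
    'pe_embedding': nn.Embedding(10, 16),  # -> 'pe_embedding.weight' (contains 'pe': still initialized)
    'linear': nn.Linear(16, 16),           # -> 'linear.weight', 'linear.bias' (initialized)
})

enable_embeddingless = True
for name, p in toy.named_parameters():
    skip = "embedding" in name and "pe" not in name and enable_embeddingless is True
    print(f"{name:20s} {'skip init' if skip else 'init'}")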
7 changes: 5 additions & 2 deletions mammoth/models/adapters.py
@@ -22,6 +22,7 @@ class AdapterLayer(ABC, nn.Module):
See also fairseq implementation:
https://github.com/ahmetustun/fairseq/blob/master/fairseq/modules/adapter_layer.py
"""

def __init__(self, input_dim, hidden_dim, pfeiffer=False, init='small', layernorm='layernorm'):
super().__init__()
# Omit LayerCache
@@ -43,8 +44,8 @@ def __init__(self, input_dim, hidden_dim, pfeiffer=False, init='small', layernor

def init_fn(tensor):
nn.init.uniform_(
tensor,
almost_zero - delta, almost_zero + delta
tensor,
almost_zero - delta, almost_zero + delta
)
elif init == 'bert':

@@ -104,6 +105,7 @@ class Adapter(nn.Module):
A container for one or several AdapterLayers,
together with layer indices for injecting into the base network.
"""

def __init__(self, adapter_group: str, sub_id: str):
super().__init__()
self.name = self._name(adapter_group, sub_id)
@@ -136,6 +138,7 @@ class TransformerAdapterMixin:
Mixin to manage one or several Adapters
for a TransformerEncoder or TransformerDecoder.
"""

def __init__(self, *args, **kwargs):
# run init of next parallel inheritance class
super(TransformerAdapterMixin, self).__init__(*args, **kwargs)
21 changes: 20 additions & 1 deletion mammoth/modules/embeddings.py
@@ -8,6 +8,7 @@
from mammoth.modules.util_class import Elementwise
# from mammoth.utils.logging import logger

import torch.nn.functional as F
# import bitsandbytes as bnb


@@ -116,6 +117,7 @@ def __init__(
feat_vocab_sizes=[],
dropout=0,
freeze_word_vecs=False,
enable_embeddingless=False
):
self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx)

@@ -145,7 +147,24 @@ def __init__(
# The embedding matrix look-up tables. The first look-up table
# is for words. Subsequent ones are for features, if any exist.
emb_params = zip(vocab_sizes, emb_dims, pad_indices)
embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params]

emb_params = zip(vocab_sizes, emb_dims, pad_indices)
if enable_embeddingless is False:
embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params]

else:

def create_embeddingless(vocab, dim, padding_idx):
one_hot_matrix = F.one_hot(torch.arange(vocab)).float()
one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((vocab, dim - vocab))), dim=1)
one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
return emb
embeddings = [
create_embeddingless(vocab, dim, padding_idx=pad)
for vocab, dim, pad in emb_params
]
emb_luts = Elementwise(feat_merge, embeddings)

# The final output size of word + feature vectors. This can vary
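
Note that torch.zeros((vocab, dim - vocab)) in create_embeddingless requires dim >= vocab, i.e. the model dimension must be at least the vocabulary size; this is fine for the byte-level vocabularies the option targets, but it would fail for large subword vocabularies. A small sanity check using the same recipe as above (standalone; the vocab/dim values are made up):

# Quick sanity check of the one-hot recipe above (standalone; values are illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, dim, padding_idx = 6, 8, 1
one_hot_matrix = F.one_hot(torch.arange(vocab)).float()
one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((vocab, dim - vocab))), dim=1)
one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
emb.weight = nn.Parameter(one_hot_embed, requires_grad=False)

print(emb.weight.shape)                     # torch.Size([6, 8]): vocab x dim
print(emb.weight[padding_idx].abs().sum())  # tensor(0.): the padding row is all zeros
print(emb.weight.requires_grad)             # False: the lookup table is never updated by training
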
8 changes: 4 additions & 4 deletions mammoth/modules/layer_stack_decoder.py
@@ -34,10 +34,10 @@ def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False):
opts.transformer_ff,
opts.copy_attn,
opts.self_attn_type,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
(
opts.attention_dropout[0]
if type(opts.attention_dropout) is list
if isinstance(opts.attention_dropout, list)
else opts.attention_dropout
),
None, # embeddings,
@@ -71,10 +71,10 @@ def from_trans_opt(cls, opts, embeddings, task, is_on_top=False):
opts.transformer_ff,
opts.copy_attn,
opts.self_attn_type,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
(
opts.attention_dropout[0]
if type(opts.attention_dropout) is list
if isinstance(opts.attention_dropout, list)
else opts.attention_dropout
),
None, # embeddings,
8 changes: 4 additions & 4 deletions mammoth/modules/layer_stack_encoder.py
@@ -32,10 +32,10 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
opts.model_dim,
opts.heads,
opts.transformer_ff,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
(
opts.attention_dropout[0]
if type(opts.attention_dropout) is list
if isinstance(opts.attention_dropout, list)
else opts.attention_dropout
),
None, # embeddings,
@@ -63,10 +63,10 @@ def from_trans_opt(cls, opts, embeddings, task):
opts.model_dim,
opts.heads,
opts.transformer_ff,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
(
opts.attention_dropout[0]
if type(opts.attention_dropout) is list
if isinstance(opts.attention_dropout, list)
else opts.attention_dropout
),
None, # embeddings,
4 changes: 2 additions & 2 deletions mammoth/modules/transformer_decoder.py
@@ -312,8 +312,8 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
opts.transformer_ff,
opts.copy_attn,
opts.self_attn_type,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
opts.attention_dropout[0] if isinstance(opts.attention_dropout, list) else opts.attention_dropout,
embeddings,
opts.max_relative_positions,
opts.aan_useffn,
4 changes: 2 additions & 2 deletions mammoth/modules/transformer_encoder.py
@@ -155,8 +155,8 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
opts.model_dim,
opts.heads,
opts.transformer_ff,
opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout,
opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
opts.attention_dropout[0] if isinstance(opts.attention_dropout, list) else opts.attention_dropout,
embeddings,
opts.max_relative_positions,
pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
9 changes: 8 additions & 1 deletion mammoth/opts.py
@@ -326,6 +326,13 @@ def model_opts(parser):
"and decoder. Need to use shared dictionary for this "
"option.",
)
group.add(
'--enable_embeddingless',
'-enable_embeddingless',
action='store_true',
help="Enable the use of byte-based embeddingless models" +
"(Shaham et. al, 2021) https://aclanthology.org/2021.naacl-main.17/",
)
group.add(
'--position_encoding',
'-position_encoding',
@@ -1102,7 +1109,7 @@ def _add_decoding_opts(parser):
)
# Decoding Length constraint
group.add('--min_length', '-min_length', type=int, default=0, help='Minimum prediction length')
group.add('--max_length', '-max_length', type=int, default=100, help='Maximum prediction length.')
group.add('--max_length', '-max_length', type=int, default=5000, help='Maximum prediction length.')
group.add(
'--max_sent_length', '-max_sent_length', action=DeprecateAction, help="Deprecated, use `-max_length` instead"
)
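
Usage note for the new option added earlier in this file: it is a plain store_true flag, so it is off by default and is enabled by passing --enable_embeddingless (or -enable_embeddingless) at training time. A standalone mimic of that behavior (argparse here stands in for MAMMOTH's actual option parser, which this diff does not show):

# Standalone mimic of the flag semantics; not MAMMOTH's parser class.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--enable_embeddingless', '-enable_embeddingless', action='store_true')

print(parser.parse_args([]).enable_embeddingless)                          # False (default)
print(parser.parse_args(['--enable_embeddingless']).enable_embeddingless)  # True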
4 changes: 2 additions & 2 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -56,9 +56,9 @@ def test_all_read(self):
def test_reroutes(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=True, as_iter=False)
self.assertTrue(type(lab) is LookAheadBucketing)
self.assertTrue(isinstance(lab, LookAheadBucketing))
not_lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=False, as_iter=False)
self.assertTrue(type(not_lab) is InferenceBatcher)
self.assertTrue(isinstance(not_lab, InferenceBatcher))

def test_always_continues(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
(Diffs for the remaining 5 changed files are not shown here.)

