Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Opts cleaning #61

Merged
merged 8 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions mammoth/bin/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,13 @@ def save_counter(counter, save_path):
for tok, count in counter.most_common():
fo.write(tok + "\t" + str(count) + "\n")

if opts.share_vocab:
raise Exception('--share_vocab not supported')
# src_counter += tgt_counter
# tgt_counter = src_counter
# logger.info(f"Counters after share:{len(src_counter)}")
# save_counter(src_counter, opts.src_vocab)

# TODO: vocab configurability is somewhat limited at the moment.
# Only language-specific vocabs are possible.
# Any attempt at vocab sharing between languages will cause the following to fail.
# Reimplementing --share_vocab may not be optimal
# (it should mean combining all sources and all targets into one big vocab?).
# Perhaps we should gracefully handle the setting where several languages point to the same vocab file?
# UPDATE: --share_vocab removed from flags (issue #60)

for src_lang, src_counter in src_counters_by_lang.items():
logger.info(f"=== Source lang: {src_lang}")
Expand Down
8 changes: 0 additions & 8 deletions mammoth/inputters/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ class DynamicDatasetIter(object):
batch_type (str): batching type to count on, choices=[tokens, sents];
batch_size (int): numbers of examples in a batch;
batch_size_multiple (int): make batch size multiply of this;
data_type (str): input data type, currently only text;
pool_size (int): accum this number of examples in a dynamic dataset;
skip_empty_level (str): security level when encountering an empty line;
stride (int): iterate data files with this stride;
Expand All @@ -237,10 +236,8 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
data_type="text",
pool_size=2048,
n_buckets=1024,
skip_empty_level='warning',
):
self.task_queue_manager = task_queue_manager
self.opts = opts
Expand All @@ -256,9 +253,6 @@ def __init__(
self.device = 'cpu'
self.pool_size = pool_size
self.n_buckets = n_buckets
if skip_empty_level not in ['silent', 'warning', 'error']:
raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
self.skip_empty_level = skip_empty_level

@classmethod
def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train):
Expand All @@ -278,10 +272,8 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra
opts.batch_type,
batch_size,
batch_size_multiple,
data_type=opts.data_type,
pool_size=opts.pool_size,
n_buckets=opts.n_buckets,
skip_empty_level=opts.skip_empty_level,
)

def _init_datasets(self):
Expand Down
9 changes: 2 additions & 7 deletions mammoth/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
EncoderAdapterLayer,
DecoderAdapterLayer,
)
from mammoth.constants import ModelTask, DefaultTokens
from mammoth.constants import DefaultTokens
from mammoth.modules.layer_stack_decoder import LayerStackDecoder
from mammoth.modules.layer_stack_encoder import LayerStackEncoder
from mammoth.modules import Embeddings
Expand Down Expand Up @@ -263,10 +263,7 @@ def create_bilingual_model(

def build_src_emb(model_opts, src_vocab):
# Build embeddings.
if model_opts.model_type == "text":
src_emb = build_embeddings(model_opts, src_vocab)
else:
src_emb = None
src_emb = build_embeddings(model_opts, src_vocab)
return src_emb


Expand All @@ -288,8 +285,6 @@ def build_task_specific_model(
checkpoint,
):
logger.info(f'TaskQueueManager: {task_queue_manager}')
if not model_opts.model_task == ModelTask.SEQ2SEQ:
raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opts.model_task} task")

src_embs = dict()
tgt_embs = dict()
Expand Down
Loading
Loading