From 6ec3fe0cd96f25d7a61861ac7edc4eb8d27a22b6 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 8 Mar 2024 11:18:18 +0200 Subject: [PATCH 1/7] first pass --- mammoth/bin/build_vocab.py | 8 +- mammoth/inputters/dataloader.py | 5 - mammoth/opts.py | 393 ++++++++--------------------- mammoth/tests/test_data_prepare.py | 7 - mammoth/transforms/tokenize.py | 1 + mammoth/utils/parse.py | 23 +- 6 files changed, 113 insertions(+), 324 deletions(-) diff --git a/mammoth/bin/build_vocab.py b/mammoth/bin/build_vocab.py index 65f408b1..069839ab 100644 --- a/mammoth/bin/build_vocab.py +++ b/mammoth/bin/build_vocab.py @@ -67,19 +67,13 @@ def save_counter(counter, save_path): for tok, count in counter.most_common(): fo.write(tok + "\t" + str(count) + "\n") - if opts.share_vocab: - raise Exception('--share_vocab not supported') - # src_counter += tgt_counter - # tgt_counter = src_counter - # logger.info(f"Counters after share:{len(src_counter)}") - # save_counter(src_counter, opts.src_vocab) - # TODO: vocab configurability is somewhat limited at the moment. # Only language-specific vocabs are possible. # Any attempt at vocab sharing between languages will cause the following to fail. # Reimplementing --share_vocab may not be optimal # (it should mean combining all sources and all targets into one big vocab?). # Perhaps we should gracefully handle the setting where several languages point to the same vocab file? + # UPDATE: --share_vocab removed from flags (issue #60) for src_lang, src_counter in src_counters_by_lang.items(): logger.info(f"=== Source lang: {src_lang}") diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index 719802cf..ac9ec452 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -240,7 +240,6 @@ def __init__( data_type="text", pool_size=2048, n_buckets=1024, - skip_empty_level='warning', ): self.task_queue_manager = task_queue_manager self.opts = opts @@ -256,9 +255,6 @@ def __init__( self.device = 'cpu' self.pool_size = pool_size self.n_buckets = n_buckets - if skip_empty_level not in ['silent', 'warning', 'error']: - raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}") - self.skip_empty_level = skip_empty_level @classmethod def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train): @@ -281,7 +277,6 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra data_type=opts.data_type, pool_size=opts.pool_size, n_buckets=opts.n_buckets, - skip_empty_level=opts.skip_empty_level, ) def _init_datasets(self): diff --git a/mammoth/opts.py b/mammoth/opts.py index 0c599aac..3e26a61e 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -1,7 +1,6 @@ """ Implementation of all available options """ import configargparse -from mammoth.constants import ModelTask from mammoth.modules.position_ffn import ACTIVATION_FUNCTIONS from mammoth.modules.position_ffn import ActivationFunction from mammoth.transforms import AVAILABLE_TRANSFORMS @@ -67,7 +66,7 @@ def _add_logging_opts(parser, is_train=True): ) group.add( '--report_stats_from_parameters', - '-report_stats_from_parameters=', + '-report_stats_from_parameters', action="store_true", help="Report parameter-level statistics in tensorboard. " "This has a huge impact on performance: only use for debugging.", @@ -107,16 +106,6 @@ def _add_dynamic_corpus_opts(parser, build_vocab_only=False): required=True, help="List of datasets and their specifications. 
See examples/*.yaml for further details.", ) - group.add( - "-skip_empty_level", - "--skip_empty_level", - default="warning", - choices=["silent", "warning", "error"], - help="Security level when encounter empty examples." - "silent: silently ignore/skip empty example;" - "warning: warning when ignore/skip empty example;" - "error: raise error & stop execution when encouter empty.", - ) group.add( "-transforms", "--transforms", @@ -196,84 +185,38 @@ def _add_dynamic_vocabs_opts(parser, build_vocab_only=False): help=("Path to save" if build_vocab_only else "Path to") + " tgt vocabulary file. " "Format: one or \t per line.", ) - group.add("-share_vocab", "--share_vocab", action="store_true", help="Share source and target vocabulary.") - group.add( - "-vocab_paths", - "--vocab_paths", - default=None, - help="file name with ENCorDEC TAB language name TAB path of the vocab.", - ) - - group.add( - "-src_feats_vocab", - "--src_feats_vocab", - help=("List of paths to save" if build_vocab_only else "List of paths to") + " src features vocabulary files. " - "Files format: one or \t per line.", - ) + # Moved to transform + # group.add("-share_vocab", "--share_vocab", action="store_true", help="Share source and target vocabulary.") if not build_vocab_only: group.add( "-src_vocab_size", "--src_vocab_size", type=int, - default=50000, - help="Maximum size of the source vocabulary.", - ) - group.add( - "-tgt_vocab_size", "--tgt_vocab_size", type=int, default=50000, help="Maximum size of the target vocabulary" - ) - group.add( - "-vocab_size_multiple", - "--vocab_size_multiple", - type=int, - default=1, - help="Make the vocabulary size a multiple of this value.", - ) - - group.add( - "-src_words_min_frequency", - "--src_words_min_frequency", - type=int, - default=0, - help="Discard source words with lower frequency.", - ) - group.add( - "-tgt_words_min_frequency", - "--tgt_words_min_frequency", - type=int, - default=0, - help="Discard target words with lower frequency.", - ) - - # Truncation options, for text corpus - group = parser.add_argument_group("Pruning") - group.add( - "--src_seq_length_trunc", - "-src_seq_length_trunc", - type=int, default=None, - help="Truncate source sequence length.", + help="Maximum size of the source vocabulary; will silently truncate your vocab file if longer.", ) group.add( - "--tgt_seq_length_trunc", - "-tgt_seq_length_trunc", + "-tgt_vocab_size", + "--tgt_vocab_size", type=int, default=None, - help="Truncate target sequence length.", + help="Maximum size of the target vocabulary; will silently truncate your vocab file if longer." ) - group = parser.add_argument_group('Embeddings') - group.add( - '-both_embeddings', - '--both_embeddings', - help="Path to the embeddings file to use for both source and target tokens.", - ) - group.add('-src_embeddings', '--src_embeddings', help="Path to the embeddings file to use for source tokens.") - group.add('-tgt_embeddings', '--tgt_embeddings', help="Path to the embeddings file to use for target tokens.") - group.add( - '-embeddings_type', '--embeddings_type', choices=["GloVe", "word2vec"], help="Type of embeddings file." - ) + # FIXME: nuked in the great refactor. 
Commenting out for now (issue #60) + # group = parser.add_argument_group('Embeddings') + # group.add( + # '-both_embeddings', + # '--both_embeddings', + # help="Path to the embeddings file to use for both source and target tokens.", + # ) + # group.add('-src_embeddings', '--src_embeddings', help="Path to the embeddings file to use for source tokens.") + # group.add('-tgt_embeddings', '--tgt_embeddings', help="Path to the embeddings file to use for target tokens.") + # group.add( + # '-embeddings_type', '--embeddings_type', choices=["GloVe", "word2vec"], help="Type of embeddings file." + # ) def _add_dynamic_transform_opts(parser): @@ -310,7 +253,7 @@ def model_opts(parser): """ # Embedding Options - group = parser.add_argument_group('Model-Embeddings') + group = parser.add_argument_group('Model- Embeddings') group.add( '--share_decoder_embeddings', @@ -318,14 +261,6 @@ def model_opts(parser): action='store_true', help="Use a shared weight matrix for the input and output word embeddings in the decoder.", ) - group.add( - '--share_embeddings', - '-share_embeddings', - action='store_true', - help="Share the word embeddings between encoder " - "and decoder. Need to use shared dictionary for this " - "option.", - ) group.add( '--enable_embeddingless', '-enable_embeddingless', @@ -339,57 +274,9 @@ def model_opts(parser): action='store_true', help="Use a sin to mark relative words positions. Necessary for non-RNN style models.", ) - group.add( - "-update_vocab", "--update_vocab", action="store_true", help="Update source and target existing vocabularies" - ) - - group = parser.add_argument_group('Model-Embedding Features') - group.add( - '--feat_merge', - '-feat_merge', - type=str, - default='concat', - choices=['concat', 'sum', 'mlp'], - help="Merge action for incorporating features embeddings. Options [concat|sum|mlp].", - ) - group.add( - '--feat_vec_size', - '-feat_vec_size', - type=int, - default=-1, - help="If specified, feature embedding sizes " - "will be set to this. Otherwise, feat_vec_exponent " - "will be used.", - ) - group.add( - '--feat_vec_exponent', - '-feat_vec_exponent', - type=float, - default=0.7, - help="If -feat_merge_size is not set, feature " - "embedding sizes will be set to N^feat_vec_exponent " - "where N is the number of values the feature takes.", - ) - - # Model Task Options - group = parser.add_argument_group("Model- Task") - group.add( - "-model_task", - "--model_task", - default=ModelTask.SEQ2SEQ, - choices=[ModelTask.SEQ2SEQ, ModelTask.LANGUAGE_MODEL], - help="Type of task for the model either seq2seq or lm", - ) # Encoder-Decoder Options group = parser.add_argument_group('Model- Encoder-Decoder') - group.add( - '--model_type', - '-model_type', - default='text', - choices=['text'], - help="Type of source model to use. Allows the system to incorporate non-text inputs. Options are [text].", - ) group.add('--model_dtype', '-model_dtype', default='fp32', choices=['fp32', 'fp16'], help='Data type of the model.') group.add( @@ -397,7 +284,7 @@ def model_opts(parser): '-encoder_type', type=str, default='transformer', - choices=['mean', 'transformer'], + choices=['mean', 'transformer'], # TODO is this mean actually supported? help="Type of encoder layer to use. Non-RNN layers " "are experimental. 
Options are " "[mean|transformer].", @@ -413,7 +300,7 @@ def model_opts(parser): "[transformer].", ) - group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') + # group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') group.add('--enc_layers', '-enc_layers', nargs='+', type=int, help='Number of layers in each encoder') group.add('--dec_layers', '-dec_layers', nargs='+', type=int, help='Number of layers in each decoder') group.add( @@ -424,14 +311,6 @@ def model_opts(parser): help="Size of rnn hidden states.", ) - # group.add( - # '--cnn_kernel_width', - # '-cnn_kernel_width', - # type=int, - # default=3, - # help="Size of windows in the cnn, the kernel_size is (cnn_kernel_width, 1) in conv layer", - # ) - group.add( '--pos_ffn_activation_fn', '-pos_ffn_activation_fn', @@ -446,76 +325,23 @@ def model_opts(parser): group.add('-normformer', '--normformer', action='store_true', help='NormFormer-style normalization') + # Attention options + group = parser.add_argument_group('Model- Attention') # group.add( - # '--input_feed', - # '-input_feed', - # type=int, - # default=1, - # help="Feed the context vector at each time step as " - # "additional input (via concatenation with the word " - # "embeddings) to the decoder.", + # '--global_attention', + # '-global_attention', + # type=str, + # default='general', + # choices=['dot', 'general', 'mlp', 'none'], + # help="The attention type to use: dotprod or general (Luong) or MLP (Bahdanau)", # ) - group.add( - '--bridge', - '-bridge', - action="store_true", - help="Have an additional layer between the last encoder state and the first decoder state", - ) - # group.add('--residual', '-residual', action="store_true", - # help="Add residual connections between RNN layers.") - # group.add( - # '--context_gate', - # '-context_gate', + # '--global_attention_function', + # '-global_attention_function', # type=str, - # default=None, - # choices=['source', 'target', 'both'], - # help="Type of context gate to use. Do not select for no context gate.", + # default="softmax", + # choices=["softmax"], # ) - - # The following options (bridge_extra_node to n_steps) are used - # for training with --encoder_type ggnn (Gated Graph Neural Network). 
- group.add( - '--bridge_extra_node', - '-bridge_extra_node', - type=bool, - default=True, - help='Graph encoder bridges only extra node to decoder as input', - ) - group.add( - '--bidir_edges', '-bidir_edges', type=bool, default=True, help='Graph encoder autogenerates bidirectional edges' - ) - group.add( - '--state_dim', '-state_dim', type=int, default=512, help='Number of state dimensions in the graph encoder' - ) - group.add('--n_edge_types', '-n_edge_types', type=int, default=2, help='Number of edge types in the graph encoder') - group.add('--n_node', '-n_node', type=int, default=2, help='Number of nodes in the graph encoder') - group.add('--n_steps', '-n_steps', type=int, default=2, help='Number of steps to advance graph encoder') - group.add( - '--src_ggnn_size', - '-src_ggnn_size', - type=int, - default=0, - help='Vocab size plus feature space for embedding input', - ) - - # Attention options - group = parser.add_argument_group('Model- Attention') - group.add( - '--global_attention', - '-global_attention', - type=str, - default='general', - choices=['dot', 'general', 'mlp', 'none'], - help="The attention type to use: dotprod or general (Luong) or MLP (Bahdanau)", - ) - group.add( - '--global_attention_function', - '-global_attention_function', - type=str, - default="softmax", - choices=["softmax"], - ) group.add( '--self_attn_type', '-self_attn_type', @@ -523,6 +349,8 @@ def model_opts(parser): default="scaled-dot", help='Self attention type in Transformer decoder layer -- currently "scaled-dot" or "average" ', ) + + # TODO is this actually in use? group.add( '--max_relative_positions', '-max_relative_positions', @@ -537,10 +365,12 @@ def model_opts(parser): group.add( '--transformer_ff', '-transformer_ff', type=int, default=2048, help='Size of hidden transformer feed-forward' ) + # TODO is this actually in use? group.add('--aan_useffn', '-aan_useffn', action="store_true", help='Turn on the FFN layer in the AAN decoder') # Alignement options - group = parser.add_argument_group('Model - Alignement') + # TODO is this actually in use? + group = parser.add_argument_group('Model - Alignment') group.add( '--lambda_align', '-lambda_align', @@ -569,6 +399,7 @@ def model_opts(parser): # Generator and loss options. group = parser.add_argument_group('Generator') + # FIXME likely broken, should be removed group.add('--copy_attn', '-copy_attn', action="store_true", help='Train copy attention layer.') group.add( '--copy_attn_type', @@ -657,7 +488,8 @@ def model_opts(parser): def _add_train_general_opts(parser): """General options for training""" group = parser.add_argument_group('General') - group.add('--data_type', '-data_type', default="text", help="Type of the source input. Options are [text].") + # TODO maybe relevant for issue #53 + # group.add('--data_type', '-data_type', default="text", help="Type of the source input. 
Options are [text].") group.add( '--save_model', @@ -665,12 +497,14 @@ def _add_train_general_opts(parser): default='model', help="Model filename (the model will be saved as _N.pt where N is the number of steps", ) + group.add( "--save_all_gpus", "-save_all_gpus", action="store_true", - help="Whether to store a model from every gpu (in addition to the modules)", + help="Deprecated.", ) + group.add( '--save_checkpoint_steps', '-save_checkpoint_steps', @@ -681,9 +515,25 @@ def _add_train_general_opts(parser): group.add( '--keep_checkpoint', '-keep_checkpoint', type=int, default=-1, help="Keep X checkpoints (negative: keep all)" ) + group.add('--train_steps', '-train_steps', type=int, default=100000, help='Number of training steps') + group.add( + '--single_pass', '-single_pass', action='store_true', help="Make a single pass over the training dataset." + ) + group.add('--epochs', '-epochs', type=int, default=0, help='Deprecated epochs see train_steps') + group.add('--valid_steps', '-valid_steps', type=int, default=10000, help='Perfom validation every X steps') + group.add( + '--early_stopping', '-early_stopping', type=int, default=0, help='Number of validation steps without improving.' + ) + group.add( + '--early_stopping_criteria', + '-early_stopping_criteria', + nargs="*", + default=None, + help='Criteria to use for early stopping.', + ) # GPU - group.add('--gpuid', '-gpuid', default=[], nargs='*', type=int, help="Deprecated see world_size and gpu_ranks.") + group = parser.add_argument_group('Computation Environment') group.add('--gpu_ranks', '-gpu_ranks', default=[], nargs='*', type=int, help="list of ranks of each process.") group.add('--n_nodes', '-n_nodes', default=1, type=int, help="total number of training nodes.") group.add( @@ -695,6 +545,7 @@ def _add_train_general_opts(parser): "When using non-distributed training (CPU, single-GPU), set to 0" ) group.add('--world_size', '-world_size', default=1, type=int, help="total number of distributed processes.") + # TODO is gpu_backend actually in use? group.add('--gpu_backend', '-gpu_backend', default="nccl", type=str, help="Type of torch distributed backend") group.add( '--gpu_verbose_level', @@ -748,21 +599,6 @@ def _add_train_general_opts(parser): help="Optimization resetter when train_from.", ) - # Pretrained word vectors - group.add( - '--pre_word_vecs_enc', - '-pre_word_vecs_enc', - help="If a valid path is specified, then this will load " - "pretrained word embeddings on the encoder side. " - "See README for specific formatting instructions.", - ) - group.add( - '--pre_word_vecs_dec', - '-pre_word_vecs_dec', - help="If a valid path is specified, then this will load " - "pretrained word embeddings on the decoder side. " - "See README for specific formatting instructions.", - ) # Freeze word vectors group.add( '--freeze_word_vecs_enc', @@ -778,8 +614,9 @@ def _add_train_general_opts(parser): ) # Optimization options - group = parser.add_argument_group('Optimization- Type') + group = parser.add_argument_group('Batching') group.add('--batch_size', '-batch_size', type=int, default=64, help='Maximum batch size for training') + group.add('--valid_batch_size', '-valid_batch_size', type=int, default=32, help='Maximum batch size for validation') group.add( '--batch_size_multiple', '-batch_size_multiple', @@ -794,32 +631,6 @@ def _add_train_general_opts(parser): choices=["sents", "tokens"], help="Batch grouping for batch_size. Standard is sents. 
Tokens will do dynamic batching", ) - group.add( - '--normalization', - '-normalization', - default='sents', - choices=["sents", "tokens"], - help='Normalization method of the gradient.', - ) - group.add( - '--accum_count', - '-accum_count', - type=int, - nargs='+', - default=[1], - help="Accumulate gradient this many times. " - "Approximately equivalent to updating " - "batch_size * accum_count batches at once. " - "Recommended for Transformer.", - ) - group.add( - '--accum_steps', - '-accum_steps', - type=int, - nargs='+', - default=[0], - help="Steps at which accum_count values change", - ) group.add( '--task_distribution_strategy', '-task_distribution_strategy', @@ -827,8 +638,6 @@ def _add_train_general_opts(parser): default='weighted_sampling', help="Strategy for the order in which tasks (e.g. language pairs) are scheduled for training" ) - group.add('--valid_steps', '-valid_steps', type=int, default=10000, help='Perfom validation every X steps') - group.add('--valid_batch_size', '-valid_batch_size', type=int, default=32, help='Maximum batch size for validation') group.add( '--max_generator_batches', '-max_generator_batches', @@ -838,21 +647,22 @@ def _add_train_general_opts(parser): "the generator on in parallel. Higher is faster, but " "uses more memory. Set to 0 to disable.", ) - group.add('--train_steps', '-train_steps', type=int, default=100000, help='Number of training steps') - group.add( - '--single_pass', '-single_pass', action='store_true', help="Make a single pass over the training dataset." - ) - group.add('--epochs', '-epochs', type=int, default=0, help='Deprecated epochs see train_steps') group.add( - '--early_stopping', '-early_stopping', type=int, default=0, help='Number of validation steps without improving.' + "-pool_size", + "--pool_size", + type=int, + default=2048, + help="Number of examples to dynamically pool before batching.", ) group.add( - '--early_stopping_criteria', - '-early_stopping_criteria', - nargs="*", - default=None, - help='Criteria to use for early stopping.', + "-n_buckets", + "--n_buckets", + type=int, + default=1024, + help="When batch_type=tokens, maximum number of bins for batching.", ) + + group = parser.add_argument_group('Optimization') group.add( '--optim', '-optim', @@ -885,6 +695,7 @@ def _add_train_general_opts(parser): default=0, help="L2 penalty (weight decay) regularizer", ) + # FIXME, mentions LSTM group.add( '--dropout', '-dropout', @@ -904,7 +715,6 @@ def _add_train_general_opts(parser): group.add( '--dropout_steps', '-dropout_steps', type=int, nargs='+', default=[0], help="Steps at which dropout changes." ) - group.add('--truncated_decoder', '-truncated_decoder', type=int, default=0, help="""Truncated bptt.""") group.add( '--adam_beta1', '-adam_beta1', @@ -964,9 +774,33 @@ def _add_train_general_opts(parser): default=1, help="Step for moving average. Default is every update, if -average_decay is set.", ) + group.add( + '--normalization', + '-normalization', + default='sents', + choices=["sents", "tokens"], + help='Normalization method of the gradient.', + ) + group.add( + '--accum_count', + '-accum_count', + type=int, + nargs='+', + default=[1], + help="Accumulate gradient this many times. " + "Approximately equivalent to updating " + "batch_size * accum_count batches at once. 
" + "Recommended for Transformer.", + ) + group.add( + '--accum_steps', + '-accum_steps', + type=int, + nargs='+', + default=[0], + help="Steps at which accum_count values change", + ) - # learning rate - group = parser.add_argument_group('Optimization- Rate') group.add( '--learning_rate', '-learning_rate', @@ -1006,24 +840,6 @@ def _add_train_general_opts(parser): _add_logging_opts(parser, is_train=True) -def _add_train_dynamic_data(parser): - group = parser.add_argument_group("Dynamic data") - group.add( - "-pool_size", - "--pool_size", - type=int, - default=2048, - help="Number of examples to dynamically pool before batching.", - ) - group.add( - "-n_buckets", - "--n_buckets", - type=int, - default=1024, - help="Maximum number of bins for batching.", - ) - - def train_opts(parser): """All options used in train.""" # options relate to data preprare @@ -1031,7 +847,6 @@ def train_opts(parser): # options relate to train model_opts(parser) _add_train_general_opts(parser) - _add_train_dynamic_data(parser) def _add_decoding_opts(parser): diff --git a/mammoth/tests/test_data_prepare.py b/mammoth/tests/test_data_prepare.py index d1780982..7b2db323 100644 --- a/mammoth/tests/test_data_prepare.py +++ b/mammoth/tests/test_data_prepare.py @@ -47,11 +47,6 @@ # prepare_fields_transforms(opts) # except SystemExit as err: # print(err) -# except IOError as err: -# if opts.skip_empty_level != 'error': -# raise err -# else: -# print(f"Catched IOError: {err}") # finally: # # Remove the generated *pt files. # for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): @@ -107,12 +102,10 @@ # [('tgt_seq_length_trunc', 1)], # [('tgt_seq_length_trunc', 5000)], # [('copy_attn', True)], -# [('share_vocab', True)], # [('n_sample', 30), # ('save_data', SAVE_DATA_PREFIX)], # [('n_sample', 30), # ('save_data', SAVE_DATA_PREFIX), -# ('skip_empty_level', 'error')] # ] # # for p in test_databuild: diff --git a/mammoth/transforms/tokenize.py b/mammoth/transforms/tokenize.py index 5c6e283a..38dfd437 100644 --- a/mammoth/transforms/tokenize.py +++ b/mammoth/transforms/tokenize.py @@ -98,6 +98,7 @@ def add_options(cls, parser): default=0, help="Only produce tgt subword in tgt_subword_vocab with frequency >= tgt_vocab_threshold.", ) + group.add('-share_vocab', '--share_vocab', action='store_true', help='use the same model for both sides') @classmethod def _validate_options(cls, opts): diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index 53941d9e..aa0cbe6e 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -191,16 +191,10 @@ def _validate_fields_opts(cls, opts, build_vocab_only=False): for feature in corpus["src_feats"].keys(): assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}" - if build_vocab_only: - if not opts.share_vocab: - assert opts.tgt_vocab, "-tgt_vocab is required if not -share_vocab." 
- return # validation when train: for key, vocab in opts.src_vocab.items(): cls._validate_file(vocab, info=f'src vocab ({key})') - if not opts.share_vocab: - for key, vocab in opts.tgt_vocab.items(): - cls._validate_file(vocab, info=f'tgt vocab ({key})') + cls._validate_file(vocab, info=f'tgt vocab ({key})') # if opts.dump_fields or opts.dump_transforms: if opts.dump_transforms: @@ -227,7 +221,7 @@ def _validate_language_model_compatibilities_opts(cls, opts): logger.info("encoder is not used for LM task") - assert opts.share_vocab and (opts.tgt_vocab is None), "vocab must be shared for LM task" + assert opts.tgt_vocab is None, "vocab must be shared for LM task" assert opts.decoder_type == "transformer", "Only transformer decoder is supported for LM task" @@ -294,13 +288,10 @@ def update_model_opts(cls, model_opts): if hasattr(model_opts, 'fix_word_vecs_dec'): model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec - if model_opts.layers > 0: - raise Exception('--layers is deprecated') - - model_opts.brnn = model_opts.encoder_type == "brnn" - - if model_opts.copy_attn_type is None: - model_opts.copy_attn_type = model_opts.global_attention + # model_opts.brnn = model_opts.encoder_type == "brnn" + # + # if model_opts.copy_attn_type is None: + # model_opts.copy_attn_type = model_opts.global_attention if model_opts.alignment_layer is None: model_opts.alignment_layer = -2 @@ -379,4 +370,4 @@ def validate_translate_opts(cls, opts): def validate_translate_opts_dynamic(cls, opts): # It comes from training # TODO: needs to be added as inference opts - opts.share_vocab = False + pass From 7ae66d6542f20f7b83b7434dc72819e26fda74d9 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 8 Mar 2024 11:35:23 +0200 Subject: [PATCH 2/7] BPTT / semi-plugged removed --- mammoth/opts.py | 19 ++----------------- mammoth/tests/test_models.py | 4 ---- mammoth/trainer.py | 23 +++-------------------- mammoth/utils/loss.py | 17 ++++++----------- mammoth/utils/parse.py | 7 ------- 5 files changed, 11 insertions(+), 59 deletions(-) diff --git a/mammoth/opts.py b/mammoth/opts.py index 3e26a61e..3ea30971 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -301,8 +301,8 @@ def model_opts(parser): ) # group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') - group.add('--enc_layers', '-enc_layers', nargs='+', type=int, help='Number of layers in each encoder') - group.add('--dec_layers', '-dec_layers', nargs='+', type=int, help='Number of layers in each decoder') + group.add('--enc_layers', '-enc_layers', nargs='+', type=int, help='Number of layers in each encoder module') + group.add('--dec_layers', '-dec_layers', nargs='+', type=int, help='Number of layers in each decoder module') group.add( '--model_dim', '-model_dim', @@ -327,21 +327,6 @@ def model_opts(parser): # Attention options group = parser.add_argument_group('Model- Attention') - # group.add( - # '--global_attention', - # '-global_attention', - # type=str, - # default='general', - # choices=['dot', 'general', 'mlp', 'none'], - # help="The attention type to use: dotprod or general (Luong) or MLP (Bahdanau)", - # ) - # group.add( - # '--global_attention_function', - # '-global_attention_function', - # type=str, - # default="softmax", - # choices=["softmax"], - # ) group.add( '--self_attn_type', '-self_attn_type', diff --git a/mammoth/tests/test_models.py b/mammoth/tests/test_models.py index aea9355f..f74d54a0 100644 --- a/mammoth/tests/test_models.py +++ b/mammoth/tests/test_models.py @@ -180,16 +180,12 @@ def test_method(self): ], # 
[('coverage_attn', True)], # [('copy_attn', True)], - # [('global_attention', 'mlp')], # [('context_gate', 'both')], # [('context_gate', 'target')], # [('context_gate', 'source')], # [('encoder_type', "brnn"), ('brnn_merge', 'sum')], # [('encoder_type', "brnn")], # [('decoder_type', 'cnn'), ('encoder_type', 'cnn')], - # [('encoder_type', 'rnn'), ('global_attention', None)], - # [('encoder_type', 'rnn'), ('global_attention', None), ('copy_attn', True), ('copy_attn_type', 'general')], - # [('encoder_type', 'rnn'), ('global_attention', 'mlp'), ('copy_attn', True), ('copy_attn_type', 'general')], # [], ] diff --git a/mammoth/trainer.py b/mammoth/trainer.py index d5edf246..480ff9e3 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -68,7 +68,6 @@ def build_trainer( mammoth.utils.loss.build_loss_compute(model, tgt_vocab, opts, train=False, generator=generator), ) - trunc_size = opts.truncated_decoder # Badly named... shard_size = opts.max_generator_batches if opts.model_dtype == 'fp32' else 0 norm_method = opts.normalization accum_count = opts.accum_count @@ -91,7 +90,6 @@ def build_trainer( train_loss_md, valid_loss_md, optim, - trunc_size, shard_size, norm_method, accum_count, @@ -126,7 +124,6 @@ class Trainer(object): training loss computation optim(:obj:`mammoth.utils.optimizers.Optimizer`): the optimizer responsible for update - trunc_size(int): length of truncated back propagation through time shard_size(int): compute loss in shards of this size for efficiency data_type(string): type of the source input: [text] norm_method(string): normalization methods: [sents|tokens] @@ -145,7 +142,6 @@ def __init__( train_loss_md, valid_loss_md, optim, - trunc_size=0, shard_size=32, norm_method="sents", accum_count=[1], @@ -169,7 +165,6 @@ def __init__( self.train_loss_md = train_loss_md self.valid_loss_md = valid_loss_md self.optim = optim - self.trunc_size = trunc_size self.shard_size = shard_size self.norm_method = norm_method self.accum_count_l = accum_count @@ -200,11 +195,6 @@ def __init__( for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 - if self.accum_count_l[i] > 1: - assert ( - self.trunc_size == 0 - ), """To enable accumulated gradients, - you must disable target sequence truncating.""" # Set model in training mode. self.model.train() @@ -476,12 +466,6 @@ def _gradient_accumulation_over_lang_pairs( # logger.info(f'batch with metadata {metadata}') target_size = batch.tgt.size(0) - # Truncated BPTT: reminder not compatible with accum > 1 - if self.trunc_size: - raise Exception('Truncated BPTT not supported') - trunc_size = self.trunc_size - else: - trunc_size = target_size src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None) if src_lengths is not None: @@ -493,9 +477,10 @@ def _gradient_accumulation_over_lang_pairs( tgt_outer = batch.tgt bptt = False - for j in range(0, target_size - 1, trunc_size): + # TODO: these loops come from truncation / BPTT implementations which are removed in #60 + for j in range(0, target_size - 1, target_size): # 1. Create truncated target. 
- tgt = tgt_outer[j:(j + trunc_size)] + tgt = tgt_outer[j:(j + target_size)] # TODO: AMP == TRUE If fp16 with torch.cuda.amp.autocast(enabled=self.optim.amp): outputs, attns = self.model( @@ -510,8 +495,6 @@ def _gradient_accumulation_over_lang_pairs( attns, normalization=normalization, shard_size=self.shard_size, - trunc_start=j, - trunc_size=trunc_size, ) # logger.info(loss) diff --git a/mammoth/utils/loss.py b/mammoth/utils/loss.py index 5061325d..9ac5088c 100644 --- a/mammoth/utils/loss.py +++ b/mammoth/utils/loss.py @@ -133,15 +133,11 @@ def _compute_loss(self, batch, output, target, **kwargs): """ return NotImplementedError - def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_start=0, trunc_size=None): + def __call__(self, batch, output, attns, normalization=1.0, shard_size=0): """Compute the forward loss, possibly in shards in which case this method also runs the backward pass and returns ``None`` as the loss value. - Also supports truncated BPTT for long sequences by taking a - range in the decoder output sequence to back propagate in. - Range is from `(trunc_start, trunc_start + trunc_size)`. - Note sharding is an exact efficiency trick to relieve memory required for the generation buffers. Truncation is an approximate efficiency trick to relieve the memory required @@ -155,16 +151,15 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ `[tgt_len x batch x src_len]` normalization: Optional normalization factor. shard_size (int) : maximum number of examples in a shard - trunc_start (int) : starting position of truncation window - trunc_size (int) : length of truncation window Returns: A tuple with the loss and a :obj:`mammoth.utils.Statistics` instance. """ - if trunc_size is None: - trunc_size = batch.tgt.size(0) - trunc_start - trunc_range = (trunc_start, trunc_start + trunc_size) - shard_state = self._make_shard_state(batch, output, trunc_range, attns) + # TODO: keeping a range_ for now, but should be removed, ideally. 
+ # Inherited from BPTT implementations + size = batch.tgt.size(0) + range_ = (0, size) + shard_state = self._make_shard_state(batch, output, range_, attns) if shard_size == 0: loss, stats = self._compute_loss(batch, **shard_state) return loss / float(normalization), stats diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index aa0cbe6e..20301ca5 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -288,11 +288,6 @@ def update_model_opts(cls, model_opts): if hasattr(model_opts, 'fix_word_vecs_dec'): model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec - # model_opts.brnn = model_opts.encoder_type == "brnn" - # - # if model_opts.copy_attn_type is None: - # model_opts.copy_attn_type = model_opts.global_attention - if model_opts.alignment_layer is None: model_opts.alignment_layer = -2 model_opts.lambda_align = 0.0 @@ -334,8 +329,6 @@ def ckpt_model_opts(cls, ckpt_opt): def validate_train_opts(cls, opts): if opts.epochs: raise AssertionError("-epochs is deprecated please use -train_steps.") - if opts.truncated_decoder > 0 and max(opts.accum_count) > 1: - raise AssertionError("BPTT is not compatible with -accum > 1") if opts.gpuid: raise AssertionError("gpuid is deprecated see world_size and gpu_ranks") From b3b6292068694ec81b0b2219e27a96b0d2eab927 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 8 Mar 2024 11:41:19 +0200 Subject: [PATCH 3/7] disambiguate filter arg groups --- mammoth/transforms/filtering.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mammoth/transforms/filtering.py b/mammoth/transforms/filtering.py index 01831681..72ec3d6e 100644 --- a/mammoth/transforms/filtering.py +++ b/mammoth/transforms/filtering.py @@ -29,7 +29,7 @@ def __init__(self, opts): @classmethod def add_options(cls, parser): """Available options relating to this Transform.""" - group = parser.add_argument_group("Transform/Filter") + group = parser.add_argument_group("Transform/Length filter") group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.") group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.") @@ -70,7 +70,7 @@ def __init__(self, opts): @classmethod def add_options(cls, parser): """Available options relating to this Transform.""" - group = parser.add_argument_group("Transform/Filter") + group = parser.add_argument_group("Transform/Word ratio filter") group.add("--word_ratio_threshold", "-word_ratio_threshold", type=int, default=3, help="Threshold for discarding sentences based on word ratio.") @@ -106,7 +106,7 @@ def __init__(self, opts): @classmethod def add_options(cls, parser): """Available options relating to this Transform.""" - group = parser.add_argument_group("Transform/Filter") + group = parser.add_argument_group("Transform/Repetitions filter") group.add("--rep_threshold", "-rep_threshold", type=int, default=2, help="Number of times the substring is repeated.") group.add("--rep_min_len", "-rep_min_len", type=int, default=3, @@ -155,7 +155,7 @@ def __init__(self, opts): @classmethod def add_options(cls, parser): """Available options relating to this Transform.""" - group = parser.add_argument_group("Transform/Filter") + group = parser.add_argument_group("Transform/Terminal punctuation filter") group.add("--punct_threshold", "-punct_threshold", type=int, default=-2, help="Minimum penalty score for discarding sentences based on their terminal punctuation signs") @@ -194,7 +194,7 @@ def __init__(self, opts): 
@classmethod def add_options(cls, parser): """Available options relating to this Transform.""" - group = parser.add_argument_group("Transform/Filter") + group = parser.add_argument_group("Transform/Non-zero numerals filter") group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5, help="Threshold for discarding sentences based on numerals between the segments with zeros removed") From c4f22d6cc8eb47c886e1d0112fa42887cb560ab0 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Tue, 12 Mar 2024 11:50:47 +0200 Subject: [PATCH 4/7] debugging --- mammoth/inputters/dataloader.py | 3 --- mammoth/model_builder.py | 7 +------ mammoth/utils/loss.py | 5 +++-- mammoth/utils/parse.py | 29 ++++++++--------------------- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index ac9ec452..df6204a7 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -215,7 +215,6 @@ class DynamicDatasetIter(object): batch_type (str): batching type to count on, choices=[tokens, sents]; batch_size (int): numbers of examples in a batch; batch_size_multiple (int): make batch size multiply of this; - data_type (str): input data type, currently only text; pool_size (int): accum this number of examples in a dynamic dataset; skip_empty_level (str): security level when encouter empty line; stride (int): iterate data files with this stride; @@ -237,7 +236,6 @@ def __init__( batch_type, batch_size, batch_size_multiple, - data_type="text", pool_size=2048, n_buckets=1024, ): @@ -274,7 +272,6 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra opts.batch_type, batch_size, batch_size_multiple, - data_type=opts.data_type, pool_size=opts.pool_size, n_buckets=opts.n_buckets, ) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 4bb1c002..36c9e0a3 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -263,10 +263,7 @@ def create_bilingual_model( def build_src_emb(model_opts, src_vocab): # Build embeddings. 
- if model_opts.model_type == "text": - src_emb = build_embeddings(model_opts, src_vocab) - else: - src_emb = None + src_emb = build_embeddings(model_opts, src_vocab) return src_emb @@ -288,8 +285,6 @@ def build_task_specific_model( checkpoint, ): logger.info(f'TaskQueueManager: {task_queue_manager}') - if not model_opts.model_task == ModelTask.SEQ2SEQ: - raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opts.model_task} task") src_embs = dict() tgt_embs = dict() diff --git a/mammoth/utils/loss.py b/mammoth/utils/loss.py index 9ac5088c..dda9521b 100644 --- a/mammoth/utils/loss.py +++ b/mammoth/utils/loss.py @@ -56,14 +56,15 @@ def build_loss_compute(model, tgt_vocab, opts, train=True, generator=None): else: raise ValueError(f"No copy generator loss defined for task {opts.model_task}") else: - if opts.model_task == ModelTask.SEQ2SEQ: + # TODO: keeping this in light of possible encoder-only / decoder-only support + if True: # opts.model_task == ModelTask.SEQ2SEQ: compute = NMTLossCompute( criterion, loss_gen, lambda_coverage=opts.lambda_coverage, lambda_align=opts.lambda_align, ) - elif opts.model_task == ModelTask.LANGUAGE_MODEL: + elif False: # elif opts.model_task == ModelTask.LANGUAGE_MODEL: assert opts.lambda_align == 0.0, "lamdba_align not supported in LM loss" compute = LMLossCompute( criterion, diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index 20301ca5..52b119b9 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -201,18 +201,6 @@ def _validate_fields_opts(cls, opts, build_vocab_only=False): assert ( opts.save_data ), "-save_data should be set if set -dump_transforms." - # Check embeddings stuff - if opts.both_embeddings is not None: - assert ( - opts.src_embeddings is None and opts.tgt_embeddings is None - ), "You don't need -src_embeddings or -tgt_embeddings \ - if -both_embeddings is set." - - if any([opts.both_embeddings is not None, opts.src_embeddings is not None, opts.tgt_embeddings is not None]): - assert opts.embeddings_type is not None, "You need to specify an -embedding_type!" - assert ( - opts.save_data - ), "-save_data should be set if use pretrained embeddings." @classmethod def _validate_language_model_compatibilities_opts(cls, opts): @@ -295,14 +283,14 @@ def update_model_opts(cls, model_opts): @classmethod def validate_model_opts(cls, model_opts): - assert model_opts.model_type in ["text"], "Unsupported model type %s" % model_opts.model_type + # assert model_opts.model_type in ["text"], "Unsupported model type %s" % model_opts.model_type # encoder and decoder should be same sizes # assert same_size, "The encoder and decoder rnns must be the same size for now" - if model_opts.share_embeddings: - if model_opts.model_type != "text": - raise AssertionError("--share_embeddings requires --model_type text.") + # if model_opts.share_embeddings: + # if model_opts.model_type != "text": + # raise AssertionError("--share_embeddings requires --model_type text.") if model_opts.lambda_align > 0.0: assert model_opts.decoder_type == 'transformer', "Only transformer is supported to joint learn alignment." 
assert ( @@ -330,8 +318,6 @@ def validate_train_opts(cls, opts): if opts.epochs: raise AssertionError("-epochs is deprecated please use -train_steps.") - if opts.gpuid: - raise AssertionError("gpuid is deprecated see world_size and gpu_ranks") if torch.cuda.is_available() and not opts.gpu_ranks: logger.warn("You have a CUDA device, should run with -gpu_ranks") if opts.world_size < len(opts.gpu_ranks): @@ -351,9 +337,10 @@ def validate_train_opts(cls, opts): opts.accum_steps ), 'Number of accum_count values must match number of accum_steps' - if opts.update_vocab: - assert opts.train_from, "-update_vocab needs -train_from option" - assert opts.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' + # TODO: do we want to remove that completely? + # if opts.update_vocab: + # assert opts.train_from, "-update_vocab needs -train_from option" + # assert opts.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' @classmethod def validate_translate_opts(cls, opts): From d67a5702319e9a9b5ce81a671ecf62e2fcc2447a Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Tue, 12 Mar 2024 11:54:39 +0200 Subject: [PATCH 5/7] removed unused import --- mammoth/model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 36c9e0a3..fb1ad9eb 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -16,7 +16,7 @@ EncoderAdapterLayer, DecoderAdapterLayer, ) -from mammoth.constants import ModelTask, DefaultTokens +from mammoth.constants import DefaultTokens from mammoth.modules.layer_stack_decoder import LayerStackDecoder from mammoth.modules.layer_stack_encoder import LayerStackEncoder from mammoth.modules import Embeddings From a57f0630efefa19ddfaf82efaa51f72f097cf088 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 17 May 2024 08:43:22 +0300 Subject: [PATCH 6/7] linting --- mammoth/model_builder.py | 2 +- mammoth/opts.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 26ec9cd8..05de3e86 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -284,7 +284,7 @@ def build_task_specific_model( task_queue_manager, checkpoint, ): - + src_embs = dict() tgt_embs = dict() diff --git a/mammoth/opts.py b/mammoth/opts.py index 1eab96b7..9f672152 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -649,7 +649,6 @@ def _add_train_general_opts(parser): "Recommended value: same as accum_count, or at least a multiple of it." ) - group = parser.add_argument_group('Optimization') group.add( '--optim', From 9a4e0f4852704cd335e2abb8a61bfedf496fb0f7 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 17 May 2024 08:48:18 +0300 Subject: [PATCH 7/7] drop some more --- mammoth/opts.py | 13 +++---------- mammoth/train_single.py | 4 ---- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/mammoth/opts.py b/mammoth/opts.py index 9f672152..bd297b40 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -285,10 +285,8 @@ def model_opts(parser): '-encoder_type', type=str, default='transformer', - choices=['mean', 'transformer'], # TODO is this mean actually supported? - help="Type of encoder layer to use. Non-RNN layers " - "are experimental. Options are " - "[mean|transformer].", + choices=['transformer'], + help="In deprecation. Only transformers are supported." 
) group.add( '--decoder_type', @@ -296,9 +294,7 @@ def model_opts(parser): type=str, default='transformer', choices=['transformer'], - help="Type of decoder layer to use. Non-RNN layers " - "are experimental. Options are " - "[transformer].", + help="In deprecation. Only transformers are supported." ) # group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') @@ -502,9 +498,6 @@ def _add_train_general_opts(parser): '--keep_checkpoint', '-keep_checkpoint', type=int, default=-1, help="Keep X checkpoints (negative: keep all)" ) group.add('--train_steps', '-train_steps', type=int, default=100000, help='Number of training steps') - group.add( - '--single_pass', '-single_pass', action='store_true', help="Make a single pass over the training dataset." - ) group.add('--epochs', '-epochs', type=int, default=0, help='Deprecated epochs see train_steps') group.add('--valid_steps', '-valid_steps', type=int, default=10000, help='Perfom validation every X steps') group.add( diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 72abab8c..aedde26e 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -161,10 +161,6 @@ def _train_iter(): else: logger.info('Starting training on CPU, could be very slow') train_steps = opts.train_steps - if opts.single_pass and train_steps > 0: - if device_context.is_master(): - logger.warning("Option single_pass is enabled, ignoring train_steps.") - train_steps = 0 logger.info("{} - Starting training".format(device_context.id)) trainer.train( train_iter,
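
Patch 2 keeps the old truncation loop in `_gradient_accumulation_over_lang_pairs` but steps it by the full target length, and `loss.py` likewise keeps a `range_ = (0, size)` spanning the whole sequence, so each batch is now processed in a single window. A minimal standalone sketch of why that retained loop degenerates to (at most) one pass — plain Python for illustration, not MAMMOTH code:

# Sketch only (not MAMMOTH code): with the step equal to the full target
# length, range(0, target_size - 1, target_size) yields only j == 0 for any
# target_size > 1, so the former "truncation window" now covers the whole
# target sequence in one shot.
def truncation_windows(target_size: int):
    return [(j, j + target_size) for j in range(0, target_size - 1, target_size)]

assert truncation_windows(50) == [(0, 50)]  # single window over the full sequence
assert truncation_windows(1) == []          # degenerate single-token target: loop body never runs

This is consistent with the simplified loss signature introduced in the same patch (`__call__(self, batch, output, attns, normalization=1.0, shard_size=0)`), which no longer takes `trunc_start` / `trunc_size`.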