diff --git a/CHANGELOG.md b/CHANGELOG.md index b8968594f..9333033ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.54] +### Added +- `--source-factor-vocabs` can be set to provide source factor vocabularies. + ## [1.18.53] ### Added - Always skipping softmax for greedy decoding by default, only for single models. diff --git a/sockeye/__init__.py b/sockeye/__init__.py index ded80be1e..81ba22787 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.53' +__version__ = '1.18.54' diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 63ac61aa4..c457d7837 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -489,6 +489,12 @@ def add_vocab_args(params): required=False, default=None, help='Existing target vocabulary (JSON).') + params.add_argument('--source-factor-vocabs', + required=False, + nargs='+', + type=regular_file(), + default=[], + help='Existing source factor vocabulary (-ies) (JSON).') params.add_argument(C.VOCAB_ARG_SHARED_VOCAB, action='store_true', default=False, diff --git a/sockeye/prepare_data.py b/sockeye/prepare_data.py index 039b84b8a..e8f6a81d7 100644 --- a/sockeye/prepare_data.py +++ b/sockeye/prepare_data.py @@ -46,8 +46,8 @@ def prepare_data(args: argparse.Namespace): bucket_width = args.bucket_width source_paths = [args.source] + args.source_factors - # NOTE: Pre-existing source factor vocabularies not yet supported for prepare data - source_factor_vocab_paths = [None] * len(args.source_factors) + source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs) + else None for i in range(len(args.source_factors))] source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths num_words_source, num_words_target = args.num_words diff --git a/sockeye/train.py b/sockeye/train.py index 4c7c06b80..49bd06d4d 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -301,7 +301,9 @@ def create_data_iters_and_vocabs(args: argparse.Namespace, else: # Load or create vocabs - source_vocab_paths = [args.source_vocab] + [None] * len(args.source_factors) + source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs) + else None for i in range(len(args.source_factors))] + source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths target_vocab_path = args.target_vocab source_vocabs, target_vocab = vocab.load_or_create_vocabs( source_paths=[args.source] + args.source_factors, diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index afaeb652d..f5fd212a8 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -35,8 +35,8 @@ validation_source='test_validation_src', validation_target='test_validation_tgt', validation_source_factors=[], output='test_output', overwrite_output=False, - source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1), - pad_vocab_to_multiple_of=None, + source_vocab=None, target_vocab=None, source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0), + word_min_count=(1, 1), pad_vocab_to_multiple_of=None, no_bucketing=False, bucket_width=10, max_seq_len=(99, 99), monitor_pattern=None, monitor_stat_func='mx_default')), @@ -50,8 +50,8 @@ validation_source='test_validation_src', validation_target='test_validation_tgt', validation_source_factors=[], output='test_output', overwrite_output=False, - source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1), - pad_vocab_to_multiple_of=None, + source_vocab=None, target_vocab=None, source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0), + word_min_count=(1, 1), pad_vocab_to_multiple_of=None, no_bucketing=False, bucket_width=10, max_seq_len=(99, 99), monitor_pattern=None, monitor_stat_func='mx_default')) ]) @@ -337,6 +337,7 @@ def test_tutorial_averaging_args(test_params, expected_params, expected_params_p source_vocab=None, target_vocab=None, source_factors=[], + source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1), @@ -360,6 +361,7 @@ def test_tutorial_prepare_data_cli_args(test_params, expected_params): source_vocab=None, target_vocab=None, source_factors=[], + source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1),