Skip to content

Commit

Permalink
Provide multiple source vocabularies as argument (#530)
Browse files Browse the repository at this point in the history
* provide multiple source vocabularies as argument

* update changelog.md

* update sockeye/__init__.py

* fix source vocabulary check (continue training)

* use regular_file() function for argparse type (source factor vocabs)

* add new source_vocab format to test_arguments

* change sockeye/__init__.py

* Update __init__.py

* add --source-factor-vocabs

* update test_arguments.py
  • Loading branch information
franckbrl authored and fhieber committed Sep 11, 2018
1 parent d182038 commit 985c97e
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.54]
### Added
- `--source-factor-vocabs` can be set to provide source factor vocabularies.

## [1.18.53]
### Added
- Always skip softmax for greedy decoding by default (single models only).
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.53'
__version__ = '1.18.54'
6 changes: 6 additions & 0 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ def add_vocab_args(params):
required=False,
default=None,
help='Existing target vocabulary (JSON).')
params.add_argument('--source-factor-vocabs',
required=False,
nargs='+',
type=regular_file(),
default=[],
help='Existing source factor vocabulary (-ies) (JSON).')
params.add_argument(C.VOCAB_ARG_SHARED_VOCAB,
action='store_true',
default=False,
Expand Down
4 changes: 2 additions & 2 deletions sockeye/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def prepare_data(args: argparse.Namespace):
bucket_width = args.bucket_width

source_paths = [args.source] + args.source_factors
# NOTE: Pre-existing source factor vocabularies not yet supported for prepare data
source_factor_vocab_paths = [None] * len(args.source_factors)
source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
else None for i in range(len(args.source_factors))]
source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths

num_words_source, num_words_target = args.num_words
Expand Down
4 changes: 3 additions & 1 deletion sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,

else:
# Load or create vocabs
source_vocab_paths = [args.source_vocab] + [None] * len(args.source_factors)
source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
else None for i in range(len(args.source_factors))]
source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths
target_vocab_path = args.target_vocab
source_vocabs, target_vocab = vocab.load_or_create_vocabs(
source_paths=[args.source] + args.source_factors,
Expand Down
10 changes: 6 additions & 4 deletions test/unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
validation_source='test_validation_src', validation_target='test_validation_tgt',
validation_source_factors=[],
output='test_output', overwrite_output=False,
source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1),
pad_vocab_to_multiple_of=None,
source_vocab=None, target_vocab=None, source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0),
word_min_count=(1, 1), pad_vocab_to_multiple_of=None,
no_bucketing=False, bucket_width=10, max_seq_len=(99, 99),
monitor_pattern=None, monitor_stat_func='mx_default')),
Expand All @@ -50,8 +50,8 @@
validation_source='test_validation_src', validation_target='test_validation_tgt',
validation_source_factors=[],
output='test_output', overwrite_output=False,
source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(0, 0), word_min_count=(1, 1),
pad_vocab_to_multiple_of=None,
source_vocab=None, target_vocab=None, source_factor_vocabs=[], shared_vocab=False, num_words=(0, 0),
word_min_count=(1, 1), pad_vocab_to_multiple_of=None,
no_bucketing=False, bucket_width=10, max_seq_len=(99, 99),
monitor_pattern=None, monitor_stat_func='mx_default'))
])
Expand Down Expand Up @@ -337,6 +337,7 @@ def test_tutorial_averaging_args(test_params, expected_params, expected_params_p
source_vocab=None,
target_vocab=None,
source_factors=[],
source_factor_vocabs=[],
shared_vocab=False,
num_words=(0, 0),
word_min_count=(1, 1),
Expand All @@ -360,6 +361,7 @@ def test_tutorial_prepare_data_cli_args(test_params, expected_params):
source_vocab=None,
target_vocab=None,
source_factors=[],
source_factor_vocabs=[],
shared_vocab=False,
num_words=(0, 0),
word_min_count=(1, 1),
Expand Down

0 comments on commit 985c97e

Please sign in to comment.