Fix: Word based batching memory with longer source sentences. (#430)
* Fix: Word based batching memory with longer source sentences.

* comment clarification.

* typos
tdomhan authored Jun 11, 2018
1 parent cd1161a commit 83468d2
Showing 4 changed files with 23 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.22]
+### Fixed
+- Make sure the default bucket is large enough with word-based batching when the source is longer than the target
+  (previously memory usage was sub-optimal in an edge case with word-based batching and source sentences longer than target sentences; see the worked example below).
+
 ## [1.18.21]
 ### Fixed
 - Constrained decoding was missing a crucial cast
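To make the edge case in the [1.18.22] entry concrete, here is a small worked example with hypothetical numbers (the buckets and word budget below are illustrative, not taken from the commit). With word-based batching, each bucket's batch size is derived from its target length, so tracking only target words understates how large a batch's source matrix can get when source sentences are longer:

batch_size_target_words = 200                     # hypothetical per-batch target-word budget
buckets = [(60, 40), (100, 100)]                  # (max source length, max target length)

# Batch sizes (sentences per batch) follow from the target length alone.
batch_sizes = [round(batch_size_target_words / tgt) for _, tgt in buckets]    # [5, 2]
target_words = [n * tgt for n, (_, tgt) in zip(batch_sizes, buckets)]         # [200, 200]
source_words = [n * src for n, (src, _) in zip(batch_sizes, buckets)]         # [300, 200]

# Counting only target words, the largest bucket (2 * 100 = 200) looks big enough for
# every batch, but the (60, 40) bucket's source matrix holds 5 * 60 = 300 words, so the
# buffer shared across buckets is undersized and extra memory gets allocated later.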
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.21'
+__version__ = '1.18.22'
7 changes: 4 additions & 3 deletions sockeye/data_io.py
@@ -187,10 +187,11 @@ def define_bucket_batch_sizes(buckets: List[Tuple[int, int]],
             batch_size_seq = batch_size
             batch_size_word = batch_size_seq * average_seq_len
         bucket_batch_sizes.append(BucketBatchSize(bucket, batch_size_seq, batch_size_word))
-        # Track largest number of target word samples in a batch
-        largest_total_num_words = max(largest_total_num_words, batch_size_seq * padded_seq_len)
+        # Track largest number of source or target word samples in a batch
+        largest_total_num_words = max(largest_total_num_words, batch_size_seq * max(*bucket))
 
-    # Final step: guarantee that largest bucket by sequence length also has largest total batch size.
+    # Final step: guarantee that largest bucket by sequence length also has a batch size so that it covers any
+    # (batch_size, len_source) and (batch_size, len_target) matrix from the data iterator to allow for memory sharing.
     # When batching by sentences, this will already be the case.
     if batch_by_words:
         padded_seq_len = max(*buckets[-1])
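The comments in the hunk above state the guarantee the fix enforces; the following sketch restates that logic in a self-contained form. It is a simplified sketch under stated assumptions, not the actual Sockeye implementation: it ignores batch_num_devices and the per-bucket average target length that define_bucket_batch_sizes uses, and the helper name sketch_bucket_batch_sizes is made up for illustration.

from math import ceil
from typing import List, NamedTuple, Tuple


class BucketBatchSize(NamedTuple):
    bucket: Tuple[int, int]           # (max source length, max target length)
    batch_size: int                   # sentences per batch for this bucket
    average_words_per_batch: float    # target words per batch


def sketch_bucket_batch_sizes(buckets: List[Tuple[int, int]],
                              batch_size: int,
                              batch_by_words: bool) -> List[BucketBatchSize]:
    bucket_batch_sizes = []
    largest_total_num_words = 0
    for source_len, target_len in buckets:
        # Word-based batching: fit roughly `batch_size` target words per batch.
        batch_size_seq = max(1, round(batch_size / target_len)) if batch_by_words else batch_size
        bucket_batch_sizes.append(
            BucketBatchSize((source_len, target_len), batch_size_seq, batch_size_seq * target_len))
        # The fix: track the largest (batch_size, padded length) matrix over *both*
        # source and target, not over the target side alone.
        largest_total_num_words = max(largest_total_num_words,
                                      batch_size_seq * max(source_len, target_len))

    if batch_by_words:
        # Grow the last bucket's batch size until its largest matrix covers every other
        # bucket's largest matrix, so memory shared across buckets is sized correctly.
        padded_seq_len = max(*buckets[-1])
        required = ceil(largest_total_num_words / padded_seq_len)
        last = bucket_batch_sizes[-1]
        if last.batch_size < required:
            bucket_batch_sizes[-1] = BucketBatchSize(last.bucket, required, required * last.bucket[1])
    return bucket_batch_sizes


# With the same hypothetical buckets as in the worked example above, the last bucket is
# bumped from 2 to 3 sentences so its 3 * 100 = 300-word matrix covers the 300 source
# words of the (60, 40) bucket.
print(sketch_bucket_batch_sizes([(60, 40), (100, 100)], batch_size=200, batch_by_words=True))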
16 changes: 13 additions & 3 deletions test/unit/test_data_io.py
@@ -223,23 +223,33 @@ def test_sample_based_define_bucket_batch_sizes():
         assert bbs.average_words_per_batch == bbs.bucket[1] * batch_size
 
 
-def test_word_based_define_bucket_batch_sizes():
+@pytest.mark.parametrize("length_ratio", [0.5, 1.5])
+def test_word_based_define_bucket_batch_sizes(length_ratio):
     batch_by_words = True
     batch_num_devices = 1
     batch_size = 200
     max_seq_len = 100
-    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5)
+    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, length_ratio)
     bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                            batch_size=batch_size,
                                                            batch_by_words=batch_by_words,
                                                            batch_num_devices=batch_num_devices,
                                                            data_target_average_len=[None] * len(buckets))
+    max_num_words = 0
     # last bucket batch size is different
     for bbs in bucket_batch_sizes[:-1]:
-        expected_batch_size = round((batch_size / bbs.bucket[1]) / batch_num_devices)
+        target_padded_seq_len = bbs.bucket[1]
+        expected_batch_size = round((batch_size / target_padded_seq_len) / batch_num_devices)
         assert bbs.batch_size == expected_batch_size
         expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
         assert bbs.average_words_per_batch == expected_average_words_per_batch
+        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))
+
+    last_bbs = bucket_batch_sizes[-1]
+    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
+    assert last_bbs.batch_size >= min_expected_batch_size
+    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
+    assert last_bbs_num_words >= max_num_words
 
 
 def _get_random_bucketed_data(buckets: List[Tuple[int, int]],
