diff --git a/CHANGELOG.md b/CHANGELOG.md index 97de0b4ec..933132132 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.33] + +### Fixed +- Two small fixes to SampleK. Before the device was not set correctly leading to issues when running sampling on GPUs. Furthermore, SampleK did not return the top-k values correctly. + ## [3.1.32] ### Added diff --git a/sockeye/__init__.py b/sockeye/__init__.py index a71336bcc..c61c69539 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.32' +__version__ = '3.1.33' diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 6d2bfd77d..34e65e4f2 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -475,9 +475,9 @@ def forward(self, scores, target_dists, finished): # n == 0 means sample from the full vocabulary. Otherwise, we sample from the top n. if self.n != 0: # select the top n in each row, via a mask - _, indices = pt.topk(target_dists, k=self.n, dim=1, largest=True, sorted=True) + values, indices = pt.topk(target_dists, k=self.n, dim=1, largest=True, sorted=True) # set items not chosen by topk to 0 - target_dists = pt.scatter(pt.zeros_like(target_dists), 1, indices, target_dists) + target_dists = pt.scatter(pt.zeros_like(target_dists), 1, indices, values) # renormalize target_dists = target_dists / target_dists.sum(1, keepdim=True) @@ -489,7 +489,7 @@ def forward(self, scores, target_dists, finished): # (batch, 1) values = scores.gather(dim=1, index=best_word_indices.long().unsqueeze(1)) # (batch,) - best_hyp_indices = pt.arange(0, best_word_indices.size()[0]) + best_hyp_indices = pt.arange(0, best_word_indices.size()[0], device=best_word_indices.device) return best_hyp_indices, best_word_indices, values