Skip to content

Commit

Permalink
fix sample
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Dec 12, 2024
1 parent b1e907c commit befcc13
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,18 @@ def _update_fn(beam: Beam) -> Beam:

def sample() -> SampleFn:
def _sample(logits: torch.Tensor, k: int) -> torch.Tensor:
assert logits.ndim == 1, "expected logits to be 1D"
k = min(k, logits.shape[-1] - torch.isneginf(logits).sum().item())
probs = torch.softmax(logits, dim=-1)
k = min(k, probs.shape[-1], int(torch.sum(probs > 0).item()))
return torch.multinomial(probs, k)

return _sample


def greedy() -> SampleFn:
def _greedy(logits: torch.Tensor, k: int) -> torch.Tensor:
k = min(k, logits.shape[-1] - int(torch.sum(torch.isinf(logits)).item()))
assert logits.ndim == 1, "expected logits to be 1D"
k = min(k, logits.shape[-1] - torch.isneginf(logits).sum().item())
return torch.topk(logits, k, dim=-1).indices

return _greedy
Expand Down

0 comments on commit befcc13

Please sign in to comment.