Skip to content

Commit

Permalink
Merge branch 'main' of github.com:ad-freiburg/text-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Dec 13, 2024
2 parents 1d006fb + 3ff6fde commit c5581d0
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,28 +270,31 @@ def _top_k(
return _top_k


def nucleus_masking(p: float, keep_min: int = 1) -> LogitFn:
    """Return a logit function applying nucleus (top-p) masking.

    Tokens outside the smallest set whose cumulative probability
    reaches ``p`` get their logits set to ``-inf``; at least
    ``keep_min`` of the highest-probability tokens are always kept.

    Args:
        p: cumulative probability threshold in ``[0, 1]``.
        keep_min: minimum number of tokens to keep per row (>= 1).

    Returns:
        A ``LogitFn`` that masks a ``(batch, vocab)`` logits tensor
        in nucleus fashion and returns it in the original token order.
    """
    assert 0.0 <= p <= 1.0, "p must be in [0, 1]"
    assert keep_min > 0, "keep_min must be positive"

    def _nuc(
        _input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]
    ) -> torch.Tensor:
        # never try to keep more tokens than the vocabulary has
        keep = min(keep_min, logits.shape[-1])
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # token i is in the nucleus if the cumulative probability
        # *before* it is still below p; shift right by one so the
        # token that first crosses p is included and the top token
        # is always kept
        nucleus = cum_sum_probs < p
        nucleus = torch.cat(
            [nucleus.new_ones((len(nucleus), 1)), nucleus[:, :-1]], dim=-1
        )
        # clamp the top keep tokens to kept; shifting the whole mask
        # by keep instead would also keep up to keep - 1 tokens past
        # the p-boundary even when the nucleus is already large enough
        nucleus[:, :keep] = True
        sorted_logits = torch.gather(logits, -1, sorted_indices)
        sorted_logits[torch.logical_not(nucleus)] = float("-inf")
        # invert the sort to restore the original token order
        return sorted_logits.gather(-1, sorted_indices.argsort(-1))

    return _nuc


def min_p_masking(min_p: float, keep_min: int = 1) -> LogitFn:
assert 0.0 <= min_p <= 1.0, "min_p must be in [0, 1]"
assert keep_min > 0, "keep_min must be positive"

def _min_p(
_input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]
Expand Down

0 comments on commit c5581d0

Please sign in to comment.