Commit

keep min for min p
bastiscode committed Dec 12, 2024
1 parent 13e43b5 commit b1e907c
Showing 2 changed files with 20 additions and 13 deletions.
python/text_utils/inference/__init__.py (14 changes: 7 additions & 7 deletions)
@@ -44,12 +44,12 @@ def beam_search(
         max_new_tokens is None or max_new_tokens > 0
     ), "max_new_tokens must be None or positive"
     if stop_condition is None:
-        stop_condition = "max score"
+        stop_condition = "max_score"
     assert stop_condition in {
-        "max score",
-        "estimated score",
-        "max outputs",
-    }, "stop condition must be 'max score', 'estimated score' or 'max outputs'"
+        "max_score",
+        "estimated_score",
+        "max_outputs",
+    }, "stop condition must be 'max_score', 'estimated_score' or 'max_outputs'"
     batch_size = len(initial)
 
     decoder_info: Any | None = None
@@ -98,7 +98,7 @@ def filter_beams() -> bool:
                 finished = False
                 continue
 
-            elif stop_condition == "max outputs":
+            elif stop_condition == "max_outputs":
                 # we are done with this batch element
                 # because we have enough finished beams
                 current_beams[idx] = []
@@ -107,7 +107,7 @@
             worst_finished = min(
                 (score_fn(b) for b in finished_beams[idx]), default=float("-inf")
            )
-            if stop_condition == "estimated score":
+            if stop_condition == "estimated_score":
                 # best current calculated from current length
                 # idea: is a current active beam better than the worst finished beam?
                 best_current = max(score_fn(b) for b in current_beams[idx])
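For context on the renamed values: "max_outputs" stops a batch element once enough beams have finished, while "estimated_score" additionally asks whether any still-active beam could beat the worst finished one. A minimal standalone sketch of that check, inferred from the hunks above (the should_stop helper and toy scores are illustrative only; the real filter_beams tracks per-element beam lists and a score_fn):

def should_stop(active_scores: list[float], finished_scores: list[float]) -> bool:
    # Worst finished beam so far; -inf if nothing has finished yet.
    worst_finished = min(finished_scores, default=float("-inf"))
    # Best still-active beam, scored at its current length.
    best_current = max(active_scores, default=float("-inf"))
    # Stop once no active beam is better than the worst finished one.
    return best_current <= worst_finished

# Toy example: both active beams already score below the worst finished beam.
assert should_stop([-2.0, -3.1], [-0.7, -1.5])
assert not should_stop([-0.5], [-0.7, -1.5])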
python/text_utils/inference/utils.py (19 changes: 13 additions & 6 deletions)
@@ -288,17 +288,24 @@ def _nuc(
     return _nuc
 
 
-def min_p_masking(min_p: float) -> LogitFn:
+def min_p_masking(min_p: float, keep_min: int = 1) -> LogitFn:
     assert 0.0 <= min_p <= 1.0, "min_p must be in [0, 1]"
 
     def _min_p(
         _input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]
     ) -> torch.Tensor:
-        masked_logits = torch.full_like(logits, float("-inf"))
         probs = torch.softmax(logits, dim=-1)
-        min_probs = probs.max(dim=-1, keepdim=True)[0] * min_p
-        mask = probs >= min_probs
-        masked_logits[mask] = logits[mask]
-        return masked_logits
+        max_probs = probs.max(dim=-1, keepdim=True)[0]
+
+        mask = probs < max_probs * min_p
+
+        sorted_indices = torch.argsort(probs, descending=True, dim=-1)
+        sorted_indices_mask = torch.gather(mask, dim=-1, index=sorted_indices)
+        sorted_indices_mask[..., :keep_min] = False
+
+        indices_mask = sorted_indices_mask.scatter(
+            -1, sorted_indices, sorted_indices_mask
+        )
+        return logits.masked_fill(indices_mask, float("-inf"))
 
     return _min_p
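The effect of the new keep_min parameter: at least keep_min tokens are now guaranteed to survive the min-p filter, because after reordering the mask by descending probability, the top keep_min positions are forced to stay unmasked before scattering back. A standalone reproduction of the new _min_p body (same logic, lifted out of the LogitFn closure; function name and test values are illustrative only):

import torch

def min_p_mask(logits: torch.Tensor, min_p: float, keep_min: int = 1) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    max_probs = probs.max(dim=-1, keepdim=True)[0]
    # Mask tokens whose probability falls below min_p times the max probability.
    mask = probs < max_probs * min_p
    # Reorder the mask by descending probability and unmask the top keep_min tokens.
    sorted_indices = torch.argsort(probs, descending=True, dim=-1)
    sorted_mask = torch.gather(mask, dim=-1, index=sorted_indices)
    sorted_mask[..., :keep_min] = False
    # Scatter the adjusted mask back to the original token order.
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, float("-inf"))

# Even with min_p = 1.0, which would otherwise mask everything but the argmax,
# keep_min=2 guarantees that two tokens keep finite logits.
out = min_p_mask(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), min_p=1.0, keep_min=2)
assert torch.isfinite(out).sum().item() == 2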
