LIT: Fix assertion error when generation stops short of max_length
PiperOrigin-RevId: 681096247
RyanMullins authored and LIT team committed Oct 1, 2024
1 parent 3dc61a2 commit 219be2e
Showing 1 changed file with 5 additions and 1 deletion.
lit_nlp/examples/prompt_debugging/transformers_lms.py (5 additions, 1 deletion)

@@ -284,7 +284,7 @@ def _get_batched_outputs(
           attention_mask=encoded_inputs["attention_mask"],
           max_length=self.max_length,
       )
-      ntok_out = self.max_length - encoded_inputs["input_ids"].shape[1]
+      ntok_out = outputs.shape[1] - encoded_inputs["input_ids"].shape[1]

       responses = self.tokenizer.batch_decode(
           outputs[:, -ntok_out:], skip_special_tokens=True
@@ -297,11 +297,13 @@ def _get_batched_outputs(

       batched_outputs = {
           "embs": embeddings.cpu().to(torch.float),
+          # Input tokens: <int>[batch_size]
           "ntok_in": (
               torch.sum(encoded_inputs["attention_mask"], axis=1)
               .cpu()
               .to(torch.int)
           ),
+          # Output tokens: <int>[batch_size]
           "ntok_out": torch.full(
               (encoded_inputs["input_ids"].shape[0],), ntok_out
           ),
@@ -310,7 +312,9 @@ def _get_batched_outputs(
       embeddings = self.embedding_table(outputs)
       batched_outputs = {
           "embs": embeddings,
+          # Input tokens: <int>[batch_size]
           "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1),
+          # Output tokens: <int>[batch_size]
           "ntok_out": tf.fill(
               [
                   encoded_inputs["input_ids"].shape[0],
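For context on the fix: `transformers`' `generate()` returns a tensor of shape `[batch_size, input_len + generated_len]`, and generation can stop before `max_length` when every sequence in the batch emits EOS. The old code assumed `generated_len == max_length - input_len`, which over-counts in that case, so the `outputs[:, -ntok_out:]` slice reached back into the prompt and tripped the downstream assertion. Below is a minimal standalone sketch of the before/after computation; the `gpt2` checkpoint and variable names are stand-ins for illustration, not LIT's actual wrapper.

```python
# Sketch only: "gpt2" is a stand-in checkpoint; LIT's real code lives in
# lit_nlp/examples/prompt_debugging/transformers_lms.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

max_length = 50
encoded = tokenizer(["The capital of France is"], return_tensors="pt")
outputs = model.generate(
    encoded["input_ids"],
    attention_mask=encoded["attention_mask"],
    max_length=max_length,
)

# Buggy: assumes generation always runs to max_length. If the model stops
# early at EOS, outputs.shape[1] < max_length, so this over-counts and the
# outputs[:, -ntok_out:] slice reaches back into the prompt tokens.
ntok_out_buggy = max_length - encoded["input_ids"].shape[1]

# Fixed: count the tokens that were actually generated.
ntok_out = outputs.shape[1] - encoded["input_ids"].shape[1]
assert ntok_out <= ntok_out_buggy

responses = tokenizer.batch_decode(
    outputs[:, -ntok_out:], skip_special_tokens=True
)
print(responses)
```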
