Skip to content

Commit

Permalink
LIT: Fix assertion error when generation stops short of max_length
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681096247
  • Loading branch information
RyanMullins authored and LIT team committed Oct 1, 2024
1 parent 3dc61a2 commit 72e4127
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions lit_nlp/examples/prompt_debugging/transformers_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,8 @@ def _get_batched_outputs(
if self.framework == MLFramework.PT:
encoded_inputs = encoded_inputs.to(self.device)

outputs = self.model.generate(
encoded_inputs["input_ids"],
attention_mask=encoded_inputs["attention_mask"],
max_length=self.max_length,
)
ntok_out = self.max_length - encoded_inputs["input_ids"].shape[1]
outputs = self.model.generate(**encoded_inputs, max_length=self.max_length)
ntok_out = outputs[0].shape[0] - encoded_inputs["input_ids"].shape[1]

responses = self.tokenizer.batch_decode(
outputs[:, -ntok_out:], skip_special_tokens=True
Expand Down

0 comments on commit 72e4127

Please sign in to comment.