diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms.py b/lit_nlp/examples/prompt_debugging/transformers_lms.py
index f493926c..0fcb9768 100644
--- a/lit_nlp/examples/prompt_debugging/transformers_lms.py
+++ b/lit_nlp/examples/prompt_debugging/transformers_lms.py
@@ -284,7 +284,7 @@ def _get_batched_outputs(
           attention_mask=encoded_inputs["attention_mask"],
           max_length=self.max_length,
       )
-    ntok_out = self.max_length - encoded_inputs["input_ids"].shape[1]
+    ntok_out = outputs.shape[1] - encoded_inputs["input_ids"].shape[1]
 
     responses = self.tokenizer.batch_decode(
         outputs[:, -ntok_out:], skip_special_tokens=True
@@ -297,11 +297,13 @@ def _get_batched_outputs(
 
       batched_outputs = {
           "embs": embeddings.cpu().to(torch.float),
+          # Input tokens: [batch_size]
           "ntok_in": (
               torch.sum(encoded_inputs["attention_mask"], axis=1)
               .cpu()
               .to(torch.int)
           ),
+          # Output tokens: [batch_size]
           "ntok_out": torch.full(
               (encoded_inputs["input_ids"].shape[0],), ntok_out
           ),
@@ -310,7 +312,9 @@ def _get_batched_outputs(
       embeddings = self.embedding_table(outputs)
       batched_outputs = {
           "embs": embeddings,
+          # Input tokens: [batch_size]
           "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1),
+          # Output tokens: [batch_size]
           "ntok_out": tf.fill(
               [
                   encoded_inputs["input_ids"].shape[0],
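
Note on the ntok_out fix: Hugging Face's generate() treats max_length as an
upper bound and can stop early (for example at an EOS token), so the output
sequence may be shorter than max_length. The old formula therefore over-counted
the generated tokens, and the decode slice outputs[:, -ntok_out:] could pull
prompt tokens into the decoded response. A minimal sketch of the arithmetic
(standalone tensors with made-up shapes, not part of the patch):

    # Sketch only: simulate a generate() call that stopped early at EOS.
    import torch

    max_length = 10
    input_ids = torch.ones((1, 4), dtype=torch.long)  # 4 prompt tokens
    # Suppose generation produced only 3 new tokens: 4 prompt + 3 = 7 total.
    outputs = torch.ones((1, 7), dtype=torch.long)

    ntok_out_old = max_length - input_ids.shape[1]        # 6 -- too large
    ntok_out_new = outputs.shape[1] - input_ids.shape[1]  # 3 -- actual count

    # With the old value, outputs[:, -6:] would include 3 prompt tokens
    # in the "response"; the new value slices exactly the generated tail.
    assert outputs[:, -ntok_out_old:].shape[1] == 6
    assert outputs[:, -ntok_out_new:].shape[1] == 3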