diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms.py b/lit_nlp/examples/prompt_debugging/transformers_lms.py index e6c66deb..7a226691 100644 --- a/lit_nlp/examples/prompt_debugging/transformers_lms.py +++ b/lit_nlp/examples/prompt_debugging/transformers_lms.py @@ -266,16 +266,17 @@ def _get_batched_outputs( data in numpy arrays (could come from torch or tensorflow, depending on the transformer backend). """ - prompts = [ex["prompt"] for ex in inputs] encoded_inputs = self.tokenizer( - prompts, - return_tensors=_HF_PYTORCH - if self.framework == MLFramework.PT - else _HF_TENSORFLOW, + [ex["prompt"] for ex in inputs], + return_tensors=( + _HF_PYTORCH if self.framework == MLFramework.PT else _HF_TENSORFLOW + ), add_special_tokens=True, padding="longest", truncation="longest_first", ) + batch_size, ntok_in = encoded_inputs["input_ids"].shape + if self.framework == MLFramework.PT: encoded_inputs = encoded_inputs.to(self.device) @@ -284,7 +285,7 @@ def _get_batched_outputs( if isinstance(outputs, transformers.utils.ModelOutput): outputs = outputs.sequences - ntok_out = outputs.shape[1] - encoded_inputs["input_ids"].shape[1] + ntok_out = outputs.shape[1] - ntok_in responses = self.tokenizer.batch_decode( outputs[:, -ntok_out:], skip_special_tokens=True @@ -293,36 +294,16 @@ def _get_batched_outputs( if self.framework == MLFramework.PT: with torch.no_grad(): # Input embeddings: [batch_size, num_tokens, emb_dim] - embeddings = self.embedding_table(outputs) - - batched_outputs = { - "embs": embeddings.cpu().to(torch.float), - "ntok_in": ( - torch.sum(encoded_inputs["attention_mask"], axis=1) - .cpu() - .to(torch.int) - ), - "ntok_out": torch.full( - (encoded_inputs["input_ids"].shape[0],), ntok_out - ), - } + embeddings = self.embedding_table(outputs).cpu().to(torch.float) else: embeddings = self.embedding_table(outputs) - batched_outputs = { - "embs": embeddings, - "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1), - "ntok_out": tf.fill( - [ - encoded_inputs["input_ids"].shape[0], - ], - ntok_out, - ), - } - # Convert to numpy for post-processing. - detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} - detached_outputs[pd_constants.FieldNames.RESPONSE] = responses - return detached_outputs + return { + "embs": embeddings.numpy(), + "ntok_in": np.array((batch_size, ntok_in)), + "ntok_out": np.full((batch_size,), ntok_out), + pd_constants.FieldNames.RESPONSE: responses, + } ## # LIT API implementations