From 0df5d9f72361abae8710675b606fbc85f44b516f Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Tue, 1 Oct 2024 11:42:31 -0700
Subject: [PATCH] LIT: Isolate DL runtime-specific generation logic for
 Transformers

PiperOrigin-RevId: 681105003
---
 .../prompt_debugging/transformers_lms.py      | 57 +++++++------------
 1 file changed, 19 insertions(+), 38 deletions(-)

diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms.py b/lit_nlp/examples/prompt_debugging/transformers_lms.py
index e6c66deb..f43a721d 100644
--- a/lit_nlp/examples/prompt_debugging/transformers_lms.py
+++ b/lit_nlp/examples/prompt_debugging/transformers_lms.py
@@ -266,21 +266,29 @@ def _get_batched_outputs(
       data in numpy arrays (could come from torch or tensorflow, depending on
       the transformer backend).
     """
-    prompts = [ex["prompt"] for ex in inputs]
     encoded_inputs = self.tokenizer(
-        prompts,
-        return_tensors=_HF_PYTORCH
-        if self.framework == MLFramework.PT
-        else _HF_TENSORFLOW,
+        [ex["prompt"] for ex in inputs],
+        return_tensors=(
+            _HF_PYTORCH if self.framework == MLFramework.PT else _HF_TENSORFLOW
+        ),
         add_special_tokens=True,
         padding="longest",
         truncation="longest_first",
     )
 
+    batch_size, ntok_in = encoded_inputs["input_ids"].shape
+
     if self.framework == MLFramework.PT:
       encoded_inputs = encoded_inputs.to(self.device)
 
     outputs = self.model.generate(**encoded_inputs, max_length=self.max_length)
 
+    if self.framework == MLFramework.PT:
+      with torch.no_grad():
+        # Input embeddings: [batch_size, num_tokens, emb_dim]
+        embeddings = self.embedding_table(outputs).cpu().to(torch.float)
+    else:
+      embeddings = self.embedding_table(outputs)
+
     if isinstance(outputs, transformers.utils.ModelOutput):
       outputs = outputs.sequences
@@ -290,39 +298,12 @@ def _get_batched_outputs(
         outputs[:, -ntok_out:], skip_special_tokens=True
     )
 
-    if self.framework == MLFramework.PT:
-      with torch.no_grad():
-        # Input embeddings: [batch_size, num_tokens, emb_dim]
-        embeddings = self.embedding_table(outputs)
-
-      batched_outputs = {
-          "embs": embeddings.cpu().to(torch.float),
-          "ntok_in": (
-              torch.sum(encoded_inputs["attention_mask"], axis=1)
-              .cpu()
-              .to(torch.int)
-          ),
-          "ntok_out": torch.full(
-              (encoded_inputs["input_ids"].shape[0],), ntok_out
-          ),
-      }
-    else:
-      embeddings = self.embedding_table(outputs)
-      batched_outputs = {
-          "embs": embeddings,
-          "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1),
-          "ntok_out": tf.fill(
-              [
-                  encoded_inputs["input_ids"].shape[0],
-              ],
-              ntok_out,
-          ),
-      }
-
-    # Convert to numpy for post-processing.
-    detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
-    detached_outputs[pd_constants.FieldNames.RESPONSE] = responses
-    return detached_outputs
+    return {
+        "embs": embeddings.numpy(),
+        "ntok_in": np.full((batch_size,), ntok_in),
+        "ntok_out": np.full((batch_size,), ntok_out),
+        pd_constants.FieldNames.RESPONSE: responses,
+    }
 
   ##
   # LIT API implementations
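
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the runtime-agnostic output contract that the refactored _get_batched_outputs converges on. The sizes and the random embs array are illustrative stand-ins; only the np.full calls and the dictionary layout mirror the patched code.

import numpy as np

# Illustrative sizes only; real values come from the tokenizer and generate().
batch_size, ntok_in, ntok_out, emb_dim = 2, 5, 3, 8

# Stand-in for self.embedding_table(outputs).numpy(): one embedding vector per
# token of the full generated sequence (prompt tokens + response tokens).
embs = np.random.rand(batch_size, ntok_in + ntok_out, emb_dim).astype(np.float32)

batched_outputs = {
    "embs": embs,
    # padding="longest" gives every row the same padded input length, so one
    # scalar broadcast with np.full replaces the per-framework reductions
    # (torch.sum / tf.reduce_sum over the attention mask) in the old code.
    "ntok_in": np.full((batch_size,), ntok_in),
    "ntok_out": np.full((batch_size,), ntok_out),
}

assert batched_outputs["embs"].shape == (batch_size, ntok_in + ntok_out, emb_dim)
assert batched_outputs["ntok_in"].tolist() == [5, 5]
assert batched_outputs["ntok_out"].tolist() == [3, 3]

Because both backends expose a .numpy() method on their tensors and the token counts can be derived once from the padded input_ids shape, the post-processing collapses into a single numpy code path, which is what lets the patch delete the two framework-specific branches.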