LIT: Isolate DL runtime-specific generation logic for Transformers

PiperOrigin-RevId: 681105003
RyanMullins authored and LIT team committed Oct 2, 2024
1 parent c86dd92 commit 0df5d9f
Showing 1 changed file with 19 additions and 38 deletions.
lit_nlp/examples/prompt_debugging/transformers_lms.py
@@ -266,21 +266,29 @@ def _get_batched_outputs(
       data in numpy arrays (could come from torch or tensorflow, depending on
       the transformer backend).
     """
-    prompts = [ex["prompt"] for ex in inputs]
     encoded_inputs = self.tokenizer(
-        prompts,
-        return_tensors=_HF_PYTORCH
-        if self.framework == MLFramework.PT
-        else _HF_TENSORFLOW,
+        [ex["prompt"] for ex in inputs],
+        return_tensors=(
+            _HF_PYTORCH if self.framework == MLFramework.PT else _HF_TENSORFLOW
+        ),
         add_special_tokens=True,
         padding="longest",
         truncation="longest_first",
     )
+    batch_size, ntok_in = encoded_inputs["input_ids"].shape
+
     if self.framework == MLFramework.PT:
       encoded_inputs = encoded_inputs.to(self.device)
 
     outputs = self.model.generate(**encoded_inputs, max_length=self.max_length)
 
+    if self.framework == MLFramework.PT:
+      with torch.no_grad():
+        # Input embeddings: <float>[batch_size, num_tokens, emb_dim]
+        embeddings = self.embedding_table(outputs).cpu().to(torch.float)
+    else:
+      embeddings = self.embedding_table(outputs)
+
     if isinstance(outputs, transformers.utils.ModelOutput):
       outputs = outputs.sequences
 
@@ -290,39 +298,12 @@ def _get_batched_outputs(
         outputs[:, -ntok_out:], skip_special_tokens=True
     )
 
-    if self.framework == MLFramework.PT:
-      with torch.no_grad():
-        # Input embeddings: <float>[batch_size, num_tokens, emb_dim]
-        embeddings = self.embedding_table(outputs)
-
-      batched_outputs = {
-          "embs": embeddings.cpu().to(torch.float),
-          "ntok_in": (
-              torch.sum(encoded_inputs["attention_mask"], axis=1)
-              .cpu()
-              .to(torch.int)
-          ),
-          "ntok_out": torch.full(
-              (encoded_inputs["input_ids"].shape[0],), ntok_out
-          ),
-      }
-    else:
-      embeddings = self.embedding_table(outputs)
-      batched_outputs = {
-          "embs": embeddings,
-          "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1),
-          "ntok_out": tf.fill(
-              [
-                  encoded_inputs["input_ids"].shape[0],
-              ],
-              ntok_out,
-          ),
-      }
-
-    # Convert to numpy for post-processing.
-    detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
-    detached_outputs[pd_constants.FieldNames.RESPONSE] = responses
-    return detached_outputs
+    return {
+        "embs": embeddings.numpy(),
+        "ntok_in": np.full((batch_size,), ntok_in),
+        "ntok_out": np.full((batch_size,), ntok_out),
+        pd_constants.FieldNames.RESPONSE: responses,
+    }
 
   ##
   # LIT API implementations
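
Why the added shape unpacking works: with padding="longest" the tokenizer
pads every row to the longest prompt, so input_ids is a dense
[batch_size, padded_len] matrix on either backend. A quick standalone check,
assuming only the standard Hugging Face tokenizer API; the gpt2 checkpoint,
the prompts, and return_tensors="np" are arbitrary choices for illustration,
not what the LIT wrapper itself uses:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

encoded = tokenizer(
    ["short prompt", "a noticeably longer prompt with many more tokens"],
    return_tensors="np",
    add_special_tokens=True,
    padding="longest",
)
batch_size, ntok_in = encoded["input_ids"].shape
print(batch_size, ntok_in)               # 2, length of the longest prompt
print(encoded["attention_mask"].sum(1))  # true unpadded length per example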
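The net effect is that everything downstream of the single framework branch
is plain numpy. A minimal sketch of that packing pattern, assuming nothing
beyond numpy itself; pack_outputs, the toy shapes, and the literal "response"
key are hypothetical stand-ins (the real method keys the response on
pd_constants.FieldNames.RESPONSE):

import numpy as np

def pack_outputs(embeddings, responses, batch_size, ntok_in, ntok_out):
  # One numpy code path replaces the old torch/tf branches: by this point
  # `embeddings` is already a dense array and the token counts are ints.
  return {
      "embs": np.asarray(embeddings, dtype=np.float32),
      "ntok_in": np.full((batch_size,), ntok_in),
      "ntok_out": np.full((batch_size,), ntok_out),
      "response": responses,  # hypothetical key; see pd_constants in LIT
  }

# Toy usage: 2 prompts padded to 5 input tokens, 3 generated tokens each.
batch_size, ntok_in, ntok_out, emb_dim = 2, 5, 3, 8
embs = np.random.rand(batch_size, ntok_in + ntok_out, emb_dim)
packed = pack_outputs(
    embs, ["response a", "response b"], batch_size, ntok_in, ntok_out
)
print({k: getattr(v, "shape", v) for k, v in packed.items()})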
