Skip to content

Commit

Permalink
LIT: Isolate DL runtime-specific generation logic for Transformers
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 681555357
  • Loading branch information
RyanMullins authored and LIT team committed Oct 2, 2024
1 parent c86dd92 commit bff43f5
Showing 1 changed file with 14 additions and 33 deletions.
47 changes: 14 additions & 33 deletions lit_nlp/examples/prompt_debugging/transformers_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,17 @@ def _get_batched_outputs(
data in numpy arrays (could come from torch or tensorflow, depending on
the transformer backend).
"""
prompts = [ex["prompt"] for ex in inputs]
encoded_inputs = self.tokenizer(
prompts,
return_tensors=_HF_PYTORCH
if self.framework == MLFramework.PT
else _HF_TENSORFLOW,
[ex["prompt"] for ex in inputs],
return_tensors=(
_HF_PYTORCH if self.framework == MLFramework.PT else _HF_TENSORFLOW
),
add_special_tokens=True,
padding="longest",
truncation="longest_first",
)
batch_size, ntok_in = encoded_inputs["input_ids"].shape

if self.framework == MLFramework.PT:
encoded_inputs = encoded_inputs.to(self.device)

Expand All @@ -284,7 +285,7 @@ def _get_batched_outputs(
if isinstance(outputs, transformers.utils.ModelOutput):
outputs = outputs.sequences

ntok_out = outputs.shape[1] - encoded_inputs["input_ids"].shape[1]
ntok_out = outputs.shape[1] - ntok_in

responses = self.tokenizer.batch_decode(
outputs[:, -ntok_out:], skip_special_tokens=True
Expand All @@ -293,36 +294,16 @@ def _get_batched_outputs(
if self.framework == MLFramework.PT:
with torch.no_grad():
# Input embeddings: <float>[batch_size, num_tokens, emb_dim]
embeddings = self.embedding_table(outputs)

batched_outputs = {
"embs": embeddings.cpu().to(torch.float),
"ntok_in": (
torch.sum(encoded_inputs["attention_mask"], axis=1)
.cpu()
.to(torch.int)
),
"ntok_out": torch.full(
(encoded_inputs["input_ids"].shape[0],), ntok_out
),
}
embeddings = self.embedding_table(outputs).cpu().to(torch.float)
else:
embeddings = self.embedding_table(outputs)
batched_outputs = {
"embs": embeddings,
"ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1),
"ntok_out": tf.fill(
[
encoded_inputs["input_ids"].shape[0],
],
ntok_out,
),
}

# Convert to numpy for post-processing.
detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
detached_outputs[pd_constants.FieldNames.RESPONSE] = responses
return detached_outputs
return {
"embs": embeddings.numpy(),
"ntok_in": np.array((batch_size, ntok_in)),
"ntok_out": np.full((batch_size,), ntok_out),
pd_constants.FieldNames.RESPONSE: responses,
}

##
# LIT API implementations
Expand Down

0 comments on commit bff43f5

Please sign in to comment.