Skip to content

Commit

Permalink
Make transformer lms HFGenerativeModel output generated texts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676097135
  • Loading branch information
bdu91 authored and LIT team committed Sep 18, 2024
1 parent 66b7c8a commit b87d3d4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions lit_nlp/examples/prompt_debugging/transformers_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,11 @@ def _postprocess(self, preds: Mapping[str, Any]) -> Mapping[str, Any]:
a dict of the processed model outputs, including the response texts and
embeddings of the input and output tokens (separated into two arrays).
"""
# TODO(b/324957491): return actual decoder scores for each generation.
# GeneratedTextCandidates should be a list[(text, score)]
# TODO(b/324957491): return actual decoder scores for each generation. For
# now, we only output GeneratedText.
processed_preds = {}
processed_preds[pd_constants.FieldNames.RESPONSE] = [
(preds[pd_constants.FieldNames.RESPONSE], 1.0)
processed_preds[pd_constants.FieldNames.RESPONSE] = preds[
pd_constants.FieldNames.RESPONSE
]
ntok_in = preds["ntok_in"]
ntok_out = preds["ntok_out"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_gpt2_generation_output(self, framework, model_path):
text=cur_input["prompt"], model=model, framework=framework
)
expected_output_embeddings = _get_text_mean_embeddings(
text=cur_output["response"][0][0], model=model, framework=framework
text=cur_output["response"], model=model, framework=framework
)
np.testing.assert_array_almost_equal(
expected_input_embeddings,
Expand Down

0 comments on commit b87d3d4

Please sign in to comment.