From b87d3d49e89a885a50753c9d30ec08192b0288aa Mon Sep 17 00:00:00 2001 From: Bin Du Date: Wed, 18 Sep 2024 12:53:11 -0700 Subject: [PATCH] Make transformer lms HFGenerativeModel output generated texts. PiperOrigin-RevId: 676097135 --- lit_nlp/examples/prompt_debugging/transformers_lms.py | 8 ++++---- .../prompt_debugging/transformers_lms_int_test.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms.py b/lit_nlp/examples/prompt_debugging/transformers_lms.py index 594ca9f1..f493926c 100644 --- a/lit_nlp/examples/prompt_debugging/transformers_lms.py +++ b/lit_nlp/examples/prompt_debugging/transformers_lms.py @@ -230,11 +230,11 @@ def _postprocess(self, preds: Mapping[str, Any]) -> Mapping[str, Any]: a dict of the processed model outputs, including the response texts and embeddings of the input and output tokens (separated into two arrays). """ - # TODO(b/324957491): return actual decoder scores for each generation. - # GeneratedTextCandidates should be a list[(text, score)] + # TODO(b/324957491): return actual decoder scores for each generation. For + # now, we only output GeneratedText. processed_preds = {} - processed_preds[pd_constants.FieldNames.RESPONSE] = [ - (preds[pd_constants.FieldNames.RESPONSE], 1.0) + processed_preds[pd_constants.FieldNames.RESPONSE] = preds[ + pd_constants.FieldNames.RESPONSE ] ntok_in = preds["ntok_in"] ntok_out = preds["ntok_out"] diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms_int_test.py b/lit_nlp/examples/prompt_debugging/transformers_lms_int_test.py index 6022cfa9..9c7cde6f 100644 --- a/lit_nlp/examples/prompt_debugging/transformers_lms_int_test.py +++ b/lit_nlp/examples/prompt_debugging/transformers_lms_int_test.py @@ -77,7 +77,7 @@ def test_gpt2_generation_output(self, framework, model_path): text=cur_input["prompt"], model=model, framework=framework ) expected_output_embeddings = _get_text_mean_embeddings( - text=cur_output["response"][0][0], model=model, framework=framework + text=cur_output["response"], model=model, framework=framework ) np.testing.assert_array_almost_equal( expected_input_embeddings,