From fc71b69dde02044f213e6ed11b3aa4986dcf6cc0 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Tue, 10 Sep 2024 05:55:05 -0700
Subject: [PATCH] LIT: Refactor model field names and specs into a constants
 module.

PiperOrigin-RevId: 672928912
---
 .../examples/prompt_debugging/constants.py    | 44 ++++++++++
 .../examples/prompt_debugging/keras_lms.py    | 85 +++++++------------
 .../prompt_debugging/transformers_lms.py      | 78 +++++++++--------
 3 files changed, 114 insertions(+), 93 deletions(-)
 create mode 100644 lit_nlp/examples/prompt_debugging/constants.py

diff --git a/lit_nlp/examples/prompt_debugging/constants.py b/lit_nlp/examples/prompt_debugging/constants.py
new file mode 100644
index 00000000..e0be4047
--- /dev/null
+++ b/lit_nlp/examples/prompt_debugging/constants.py
@@ -0,0 +1,44 @@
+"""Constants used across parallel classes in the Prompt Debugging example."""
+
+import types
+from lit_nlp.api import types as lit_types
+
+
+class FieldNames(types.SimpleNamespace):
+  PROMPT = "prompt"
+  RESPONSE = "response"
+  PROMPT_EMBEDDINGS = "prompt_embeddings"
+  RESPONSE_EMBEDDINGS = "response_embeddings"
+  TARGET = "target"
+  TOKENS = "tokens"
+  TARGET_MASK = "target_mask"
+  GRAD_DOT_INPUT = "grad_dot_input"
+  GRAD_NORM = "grad_l2"
+
+
+INPUT_SPEC: lit_types.Spec = {
+    FieldNames.PROMPT: lit_types.TextSegment(),
+    FieldNames.TARGET: lit_types.TextSegment(required=False),
+}
+
+INPUT_SPEC_SALIENCE: lit_types.Spec = {
+    FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
+}
+
+OUTPUT_SPEC_GENERATION: lit_types.Spec = {
+    FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
+}
+
+OUTPUT_SPEC_GENERATION_EMBEDDINGS: lit_types.Spec = {
+    FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(required=False),
+    FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(required=False),
+}
+
+OUTPUT_SPEC_TOKENIZER: lit_types.Spec = {
+    FieldNames.TOKENS: lit_types.Tokens(parent=""),
+}
+
+OUTPUT_SPEC_SALIENCE: lit_types.Spec = {
+    FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(align=FieldNames.TOKENS),
+    FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
+} | OUTPUT_SPEC_TOKENIZER
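The spec fragments above are plain dicts keyed by the FieldNames values, so they compose with the dict union operator; that is how the model wrappers below assemble their LIT specs. A minimal sketch of the intended composition, assuming lit_nlp is installed and this patch is applied (illustrative only, not part of the patch):

# Sketch: composing the shared spec fragments defined in constants.py above.
from lit_nlp.examples.prompt_debugging import constants as pd_constants

# Generation wrappers expose the generation spec, optionally extended with
# embeddings (see KerasGenerationModel and HFGenerativeModel below).
generation_output_spec = (
    pd_constants.OUTPUT_SPEC_GENERATION
    | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
)

# Salience wrappers extend the shared input spec with the target mask; their
# output spec already unions in OUTPUT_SPEC_TOKENIZER.
salience_input_spec = pd_constants.INPUT_SPEC | pd_constants.INPUT_SPEC_SALIENCE

print(sorted(generation_output_spec))  # ['prompt_embeddings', 'response', 'response_embeddings']
print(sorted(salience_input_spec))  # ['prompt', 'target', 'target_mask']
print(sorted(pd_constants.OUTPUT_SPEC_SALIENCE))  # ['grad_dot_input', 'grad_l2', 'tokens']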
diff --git a/lit_nlp/examples/prompt_debugging/keras_lms.py b/lit_nlp/examples/prompt_debugging/keras_lms.py
index d42fd9f9..9c546eb5 100644
--- a/lit_nlp/examples/prompt_debugging/keras_lms.py
+++ b/lit_nlp/examples/prompt_debugging/keras_lms.py
@@ -3,7 +3,6 @@
 from collections.abc import Sequence
 import functools
 import inspect
-import types
 from typing import Optional
 
 from absl import logging
@@ -11,6 +10,7 @@
 from keras_nlp import models as keras_models
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
+from lit_nlp.examples.prompt_debugging import constants as pd_constants
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 from lit_nlp.lib import utils as lit_utils
 
@@ -35,19 +35,6 @@
 _DEFAULT_MAX_LENGTH = 1024
 
 
-class FieldNames(types.SimpleNamespace):
-  PROMPT = "prompt"
-  RESPONSE = "response"
-  PROMPT_EMBEDDINGS = "prompt_embeddings"
-  RESPONSE_EMBEDDINGS = "response_embeddings"
-  TARGET = "target"
-  TOKENS = "tokens"
-  TARGET_MASK = "target_mask"
-  GRAD_DOT_INPUT = "grad_dot_input"
-  GRAD_NORM = "grad_l2"
-  TOKEN_LOSS = "token_loss"
-
-
 class _KerasBaseModel(lit_model.BatchedModel):
   """Base LIT model wrapper class for Keras on TensorFlow."""
 
@@ -183,10 +170,7 @@ def init_spec(cls):
     return None
 
   def input_spec(self):
-    return {
-        FieldNames.PROMPT: lit_types.TextSegment(),
-        FieldNames.TARGET: lit_types.TextSegment(required=False),
-    }
+    return pd_constants.INPUT_SPEC
 
 
 class KerasGenerationModel(_KerasBaseModel):
@@ -240,7 +224,9 @@ def predict_minibatch(
       self,
       inputs: list[lit_types.JsonDict],
   ) -> list[lit_types.JsonDict]:
-    prompts: Sequence[str] = [ex[FieldNames.PROMPT] for ex in inputs]
+    prompts: Sequence[str] = [
+        ex[pd_constants.FieldNames.PROMPT] for ex in inputs
+    ]
 
     # TODO(lit-dev): suppport loading cached responses here, since running
     # generation can be expensive.
@@ -254,7 +240,9 @@
         for response, prompt in zip(full_responses, prompts)
     ]
 
-    outputs = [{FieldNames.RESPONSE: response} for response in responses]
+    outputs = [
+        {pd_constants.FieldNames.RESPONSE: response} for response in responses
+    ]
 
     if self.output_embeddings:
       prompt_embeddings = self.embed_and_mean_pool(prompts)
@@ -263,20 +251,19 @@
       response_embeddings = self.embed_and_mean_pool(responses)
 
       for o, p, r in zip(outputs, prompt_embeddings, response_embeddings):
-        o[FieldNames.PROMPT_EMBEDDINGS] = keras.ops.convert_to_numpy(p)
-        o[FieldNames.RESPONSE_EMBEDDINGS] = keras.ops.convert_to_numpy(r)
+        o[pd_constants.FieldNames.PROMPT_EMBEDDINGS] = (
+            keras.ops.convert_to_numpy(p)
+        )
+        o[pd_constants.FieldNames.RESPONSE_EMBEDDINGS] = (
+            keras.ops.convert_to_numpy(r)
+        )
 
     return outputs
 
   def output_spec(self) -> lit_types.Spec:
-    ret = {
-        FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
-    }
+    ret = pd_constants.OUTPUT_SPEC_GENERATION
     if self.output_embeddings:
-      return ret | {
-          FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(),
-          FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(),
-      }
+      return ret | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
     return ret
 
 
@@ -355,11 +342,8 @@ def _pred(self, input_ids, padding_mask, target_masks):
     batched_outputs = {
         "input_ids": input_ids,
         "padding_mask": padding_mask,
-        # Gradients are already aligned to input tokens.
-        FieldNames.GRAD_NORM: grad_l2,
-        FieldNames.GRAD_DOT_INPUT: grad_dot_input,
-        # Shift token loss to align with (input) tokens.
-        # FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
+        pd_constants.FieldNames.GRAD_NORM: grad_l2,
+        pd_constants.FieldNames.GRAD_DOT_INPUT: grad_dot_input,
     }
 
     return batched_outputs
@@ -447,7 +431,7 @@ def _postprocess(self, preds):
     """Post-process single-example preds. Operates on numpy arrays."""
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     for key in lit_utils.find_spec_keys(
         self.output_spec(), lit_types.TokenScores
     ):
@@ -460,13 +444,17 @@ def predict_minibatch(self, inputs):
     """Predict on a single minibatch of examples."""
     texts: Sequence[str] = [
-        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
     ]
     preprocessed_texts = self.encode_inputs(texts)
     sequence_ids = preprocessed_texts["token_ids"]
     padding_mask = preprocessed_texts["padding_mask"]
 
-    target_masks = [ex.get(FieldNames.TARGET_MASK, []) for ex in inputs]
+    target_masks = [
+        ex.get(pd_constants.FieldNames.TARGET_MASK, []) for ex in inputs
+    ]
 
     # Get the predictions.
     batched_outputs = self._pred(sequence_ids, padding_mask, target_masks)
@@ -479,19 +467,10 @@ def predict_minibatch(self, inputs):
     return map(self._postprocess, unbatched_outputs)
 
   def input_spec(self):
-    return super().input_spec() | {
-        FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
-    }
+    return super().input_spec() | pd_constants.INPUT_SPEC_SALIENCE
 
   def output_spec(self) -> lit_types.Spec:
-    return {
-        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
-        FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
-        FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(
-            align=FieldNames.TOKENS
-        ),
-        # FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
-    }
+    return pd_constants.OUTPUT_SPEC_SALIENCE
 
 
 class KerasTokenizerModel(_KerasBaseModel):
@@ -507,13 +486,15 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("token_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     return preds
 
   def predict_minibatch(self, inputs):
     """Tokenize a single minibatch of examples."""
     texts: Sequence[str] = [
-        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
     ]
     preprocessed_texts = self.encode_inputs(texts)
     batched_outputs = {
@@ -529,9 +510,7 @@
     return map(self._postprocess, unbatched_outputs)
 
   def output_spec(self) -> lit_types.Spec:
-    return {
-        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
-    }
+    return pd_constants.OUTPUT_SPEC_TOKENIZER
 
 
 def initialize_model_group_for_salience(
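Both backends now read examples and key their outputs through the same pd_constants.FieldNames attributes, which is what keeps the Keras wrappers above and the Transformers wrappers below in sync. A small illustration of that shared pattern, assuming the patch is applied; the example record below is hypothetical:

# Sketch: the prompt/target handling shared by the wrappers' predict_minibatch.
from lit_nlp.examples.prompt_debugging import constants as pd_constants

example = {
    pd_constants.FieldNames.PROMPT: "Explain beam search in one sentence.",
    pd_constants.FieldNames.TARGET: "",
}

# Salience and tokenizer wrappers concatenate the prompt with the optional
# target before tokenizing, mirroring predict_minibatch above.
text = (
    example[pd_constants.FieldNames.PROMPT]
    + example.get(pd_constants.FieldNames.TARGET, "")
)

# Generation wrappers key their outputs on the same names, so downstream
# components need not care which backend produced them.
output = {pd_constants.FieldNames.RESPONSE: "A placeholder generated response."}
print(text)
print(output[pd_constants.FieldNames.RESPONSE])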
diff --git a/lit_nlp/examples/prompt_debugging/transformers_lms.py b/lit_nlp/examples/prompt_debugging/transformers_lms.py
index ee60a432..8cdf603a 100644
--- a/lit_nlp/examples/prompt_debugging/transformers_lms.py
+++ b/lit_nlp/examples/prompt_debugging/transformers_lms.py
@@ -14,6 +14,7 @@
 from absl import logging
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
+from lit_nlp.examples.prompt_debugging import constants as pd_constants
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 from lit_nlp.lib import file_cache
 from lit_nlp.lib import utils
@@ -179,10 +180,7 @@ def max_minibatch_size(self) -> int:
     return self.batch_size
 
   def input_spec(self):
-    return {
-        "prompt": lit_types.TextSegment(),
-        "target": lit_types.TextSegment(required=False),
-    }
+    return pd_constants.INPUT_SPEC
 
 
 class HFGenerativeModel(HFBaseModel):
@@ -229,19 +227,23 @@ def _postprocess(self, preds: Mapping[str, Any]) -> Mapping[str, Any]:
     # TODO(b/324957491): return actual decoder scores for each generation.
     # GeneratedTextCandidates should be a list[(text, score)]
     processed_preds = {}
-    processed_preds["response"] = [(preds["response"], 1.0)]
+    processed_preds[pd_constants.FieldNames.RESPONSE] = [
+        (preds[pd_constants.FieldNames.RESPONSE], 1.0)
+    ]
     ntok_in = preds["ntok_in"]
     ntok_out = preds["ntok_out"]
     embs = preds["embs"]
     assert embs.shape[0] >= ntok_in + ntok_out
     # Mean-pool over input tokens.
-    processed_preds["prompt_embeddings"] = np.mean(
+    processed_preds[pd_constants.FieldNames.PROMPT_EMBEDDINGS] = np.mean(
         embs[-(ntok_out + ntok_in) : -ntok_out], axis=0
     )
     # Mean-pool over output (generated) tokens.
     # TODO(b/324957491): slice this to only "real" output tokens,
     # if generation length < max generation length.
-    processed_preds["response_embeddings"] = np.mean(embs[-ntok_out:], axis=0)
+    processed_preds[pd_constants.FieldNames.RESPONSE_EMBEDDINGS] = np.mean(
+        embs[-ntok_out:], axis=0
+    )
 
     return processed_preds
 
@@ -313,7 +315,7 @@ def _get_batched_outputs(
 
     # Convert to numpy for post-processing.
     detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
-    detached_outputs["response"] = responses
+    detached_outputs[pd_constants.FieldNames.RESPONSE] = responses
     return detached_outputs
 
   ##
@@ -326,11 +328,10 @@ def predict_minibatch(self, inputs):
     return map(self._postprocess, unbatched_outputs)
 
   def output_spec(self) -> lit_types.Spec:
-    return {
-        "response": lit_types.GeneratedTextCandidates(parent="target"),
-        "prompt_embeddings": lit_types.Embeddings(required=False),
-        "response_embeddings": lit_types.Embeddings(required=False),
-    }
+    return (
+        pd_constants.OUTPUT_SPEC_GENERATION
+        | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
+    )
 
 
 class HFSalienceModel(HFBaseModel):
@@ -422,11 +423,8 @@ def _pred_tf(self, encoded_inputs, target_masks):
     batched_outputs = {
         "input_ids": input_ids,
         "attention_mask": encoded_inputs["attention_mask"],
-        # Gradients are already aligned to input tokens.
-        "grad_l2": grad_l2,
-        "grad_dot_input": grad_dot_input,
-        # Shift token loss to align with (input) tokens.
-        # "token_loss": tf.roll(per_token_loss, shift=1, axis=1),
+        pd_constants.FieldNames.GRAD_NORM: grad_l2,
+        pd_constants.FieldNames.GRAD_DOT_INPUT: grad_dot_input,
     }
 
     return batched_outputs
@@ -486,11 +484,10 @@ def _pred_pt(self, encoded_inputs, target_masks):
     batched_outputs = {
         "input_ids": input_ids.cpu().to(torch.int),
         "attention_mask": attention_mask.cpu().to(torch.int),
-        # Gradients are already aligned to input tokens.
-        "grad_l2": grad_l2.cpu().to(torch.float),
-        "grad_dot_input": grad_dot_input.cpu().to(torch.float),
-        # Shift token loss to align with (input) tokens.
-        # "token_loss": torch.roll(per_token_loss, shifts=1, dims=1),
+        pd_constants.FieldNames.GRAD_NORM: grad_l2.cpu().to(torch.float),
+        pd_constants.FieldNames.GRAD_DOT_INPUT: grad_dot_input.cpu().to(
+            torch.float
+        ),
     }
 
     return batched_outputs
@@ -501,7 +498,7 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("attention_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds["tokens"] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     for key in utils.find_spec_keys(self.output_spec(), lit_types.TokenScores):
       preds[key] = preds[key][mask]
     # First token (usually ) is not actually predicted, so return 0 for loss.
@@ -513,7 +510,11 @@
   def predict_minibatch(self, inputs):
     """Predict on a single minibatch of examples."""
     # Preprocess inputs.
-    texts = [ex["prompt"] + ex.get("target", "") for ex in inputs]
+    texts = [
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
+    ]
     encoded_inputs = self.tokenizer(
         texts,
         return_tensors=_HF_PYTORCH
@@ -523,7 +524,9 @@
         padding="longest",
         truncation="longest_first",
     )
-    target_masks = [ex.get("target_mask", []) for ex in inputs]
+    target_masks = [
+        ex.get(pd_constants.FieldNames.TARGET_MASK, []) for ex in inputs
+    ]
 
     # Get the predictions.
     if self.framework == MLFramework.PT:
@@ -538,17 +541,10 @@ def predict_minibatch(self, inputs):
     return map(self._postprocess, unbatched_outputs)
 
   def input_spec(self):
-    return super().input_spec() | {
-        "target_mask": lit_types.TokenScores(align="", required=False),
-    }
+    return super().input_spec() | pd_constants.INPUT_SPEC_SALIENCE
 
   def output_spec(self) -> lit_types.Spec:
-    return {
-        "tokens": lit_types.Tokens(parent=""),  # all tokens
-        "grad_l2": lit_types.TokenScores(align="tokens"),
-        "grad_dot_input": lit_types.TokenScores(align="tokens"),
-        # "token_loss": lit_types.TokenScores(align="tokens"),
-    }
+    return pd_constants.OUTPUT_SPEC_SALIENCE
 
 
 class HFTokenizerModel(HFBaseModel):
@@ -563,14 +559,18 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("attention_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds["tokens"] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     return preds
 
   # LIT API implementations
   def predict_minibatch(self, inputs):
     """Predict on a single minibatch of examples."""
     # Preprocess inputs.
-    texts = [ex["prompt"] + ex.get("target", "") for ex in inputs]
+    texts = [
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
+    ]
     encoded_inputs = self.tokenizer(
         texts,
         return_tensors=_HF_PYTORCH
@@ -591,9 +591,7 @@
     return map(self._postprocess, unbatched_outputs)
 
   def output_spec(self) -> lit_types.Spec:
-    return {
-        "tokens": lit_types.Tokens(parent=""),  # all tokens
-    }
+    return pd_constants.OUTPUT_SPEC_TOKENIZER
 
 
 def initialize_model_group_for_salience(
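A final consistency sketch, assuming the patch is applied: the centralized FieldNames values match the string literals they replace throughout keras_lms.py and transformers_lms.py, and the tokenizer fields are folded into the salience spec by the union in constants.py (illustrative only, not part of the patch):

# Sketch: sanity-check the centralized names against the replaced literals.
from lit_nlp.examples.prompt_debugging import constants as pd_constants

assert pd_constants.FieldNames.PROMPT == "prompt"
assert pd_constants.FieldNames.TARGET == "target"
assert pd_constants.FieldNames.TARGET_MASK == "target_mask"
assert pd_constants.FieldNames.RESPONSE == "response"
assert pd_constants.FieldNames.TOKENS == "tokens"
assert pd_constants.FieldNames.GRAD_NORM == "grad_l2"
assert pd_constants.FieldNames.GRAD_DOT_INPUT == "grad_dot_input"

# OUTPUT_SPEC_SALIENCE already includes the tokenizer fields via the
# `| OUTPUT_SPEC_TOKENIZER` union in constants.py.
assert set(pd_constants.OUTPUT_SPEC_TOKENIZER) <= set(
    pd_constants.OUTPUT_SPEC_SALIENCE
)
print("Field names and spec composition are consistent.")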