LIT: Refactor model field names and specs into a constants module.
PiperOrigin-RevId: 672928912
RyanMullins authored and LIT team committed Sep 10, 2024
1 parent 75da3ef commit fc71b69
Showing 3 changed files with 114 additions and 93 deletions.
44 changes: 44 additions & 0 deletions lit_nlp/examples/prompt_debugging/constants.py
@@ -0,0 +1,44 @@
"""Constants used across parallel classes in the Prompt Debugging example."""

import types
from lit_nlp.api import types as lit_types


class FieldNames(types.SimpleNamespace):
PROMPT = "prompt"
RESPONSE = "response"
PROMPT_EMBEDDINGS = "prompt_embeddings"
RESPONSE_EMBEDDINGS = "response_embeddings"
TARGET = "target"
TOKENS = "tokens"
TARGET_MASK = "target_mask"
GRAD_DOT_INPUT = "grad_dot_input"
GRAD_NORM = "grad_l2"


INPUT_SPEC: lit_types.Spec = {
FieldNames.PROMPT: lit_types.TextSegment(),
FieldNames.TARGET: lit_types.TextSegment(required=False),
}

INPUT_SPEC_SALIENCE: lit_types.Spec = {
FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
}

OUTPUT_SPEC_GENERATION: lit_types.Spec = {
FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
}

OUTPUT_SPEC_GENERATION_EMBEDDINGS: lit_types.Spec = {
FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(required=False),
FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(required=False),
}

OUTPUT_SPEC_TOKENIZER: lit_types.Spec = {
FieldNames.TOKENS: lit_types.Tokens(parent=""),
}

OUTPUT_SPEC_SALIENCE: lit_types.Spec = {
FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(align=FieldNames.TOKENS),
FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
} | OUTPUT_SPEC_TOKENIZER
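
Note on usage: LIT specs are plain dicts mapping field-name strings to LitType instances, so the constants above compose with Python's dict-union operator (3.9+). A minimal sketch of how a wrapper might combine them; the assertions simply restate the definitions above:

# Sketch: composing the shared spec constants as a model wrapper would.
from lit_nlp.examples.prompt_debugging import constants as pd_constants

# FieldNames subclasses types.SimpleNamespace, so each attribute is a string.
assert pd_constants.FieldNames.PROMPT == "prompt"

# A salience model's input spec: the base fields plus the optional target mask.
input_spec = pd_constants.INPUT_SPEC | pd_constants.INPUT_SPEC_SALIENCE
assert set(input_spec) == {"prompt", "target", "target_mask"}

# OUTPUT_SPEC_SALIENCE already folds in OUTPUT_SPEC_TOKENIZER via `|`, so a
# salience model reports its tokens alongside both per-token saliency maps.
assert set(pd_constants.OUTPUT_SPEC_SALIENCE) == {
    "tokens", "grad_l2", "grad_dot_input",
}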
85 changes: 32 additions & 53 deletions lit_nlp/examples/prompt_debugging/keras_lms.py
@@ -3,14 +3,14 @@
 from collections.abc import Sequence
 import functools
 import inspect
-import types
 from typing import Optional

 from absl import logging
 import keras
 from keras_nlp import models as keras_models
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
+from lit_nlp.examples.prompt_debugging import constants as pd_constants
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 from lit_nlp.lib import utils as lit_utils

@@ -35,19 +35,6 @@
 _DEFAULT_MAX_LENGTH = 1024


-class FieldNames(types.SimpleNamespace):
-  PROMPT = "prompt"
-  RESPONSE = "response"
-  PROMPT_EMBEDDINGS = "prompt_embeddings"
-  RESPONSE_EMBEDDINGS = "response_embeddings"
-  TARGET = "target"
-  TOKENS = "tokens"
-  TARGET_MASK = "target_mask"
-  GRAD_DOT_INPUT = "grad_dot_input"
-  GRAD_NORM = "grad_l2"
-  TOKEN_LOSS = "token_loss"
-
-
 class _KerasBaseModel(lit_model.BatchedModel):
   """Base LIT model wrapper class for Keras on TensorFlow."""

@@ -183,10 +170,7 @@ def init_spec(cls):
     return None

   def input_spec(self):
-    return {
-        FieldNames.PROMPT: lit_types.TextSegment(),
-        FieldNames.TARGET: lit_types.TextSegment(required=False),
-    }
+    return pd_constants.INPUT_SPEC


 class KerasGenerationModel(_KerasBaseModel):
@@ -240,7 +224,9 @@ def predict_minibatch(
       self,
       inputs: list[lit_types.JsonDict],
   ) -> list[lit_types.JsonDict]:
-    prompts: Sequence[str] = [ex[FieldNames.PROMPT] for ex in inputs]
+    prompts: Sequence[str] = [
+        ex[pd_constants.FieldNames.PROMPT] for ex in inputs
+    ]

     # TODO(lit-dev): suppport loading cached responses here, since running
     # generation can be expensive.
@@ -254,7 +240,9 @@ def predict_minibatch(
         for response, prompt in zip(full_responses, prompts)
     ]

-    outputs = [{FieldNames.RESPONSE: response} for response in responses]
+    outputs = [
+        {pd_constants.FieldNames.RESPONSE: response} for response in responses
+    ]

     if self.output_embeddings:
       prompt_embeddings = self.embed_and_mean_pool(prompts)
@@ -263,20 +251,19 @@ def predict_minibatch(
       response_embeddings = self.embed_and_mean_pool(responses)

       for o, p, r in zip(outputs, prompt_embeddings, response_embeddings):
-        o[FieldNames.PROMPT_EMBEDDINGS] = keras.ops.convert_to_numpy(p)
-        o[FieldNames.RESPONSE_EMBEDDINGS] = keras.ops.convert_to_numpy(r)
+        o[pd_constants.FieldNames.PROMPT_EMBEDDINGS] = (
+            keras.ops.convert_to_numpy(p)
+        )
+        o[pd_constants.FieldNames.RESPONSE_EMBEDDINGS] = (
+            keras.ops.convert_to_numpy(r)
+        )

     return outputs

   def output_spec(self) -> lit_types.Spec:
-    ret = {
-        FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
-    }
+    ret = pd_constants.OUTPUT_SPEC_GENERATION
     if self.output_embeddings:
-      return ret | {
-          FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(),
-          FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(),
-      }
+      return ret | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
     return ret

@@ -355,11 +342,8 @@ def _pred(self, input_ids, padding_mask, target_masks):
     batched_outputs = {
         "input_ids": input_ids,
         "padding_mask": padding_mask,
-        # Gradients are already aligned to input tokens.
-        FieldNames.GRAD_NORM: grad_l2,
-        FieldNames.GRAD_DOT_INPUT: grad_dot_input,
-        # Shift token loss to align with (input) tokens.
-        # FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
+        pd_constants.FieldNames.GRAD_NORM: grad_l2,
+        pd_constants.FieldNames.GRAD_DOT_INPUT: grad_dot_input,
     }

     return batched_outputs
@@ -447,7 +431,7 @@ def _postprocess(self, preds):
     """Post-process single-example preds. Operates on numpy arrays."""
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     for key in lit_utils.find_spec_keys(
         self.output_spec(), lit_types.TokenScores
     ):
@@ -460,13 +444,17 @@ def predict_minibatch(self, inputs):
   def predict_minibatch(self, inputs):
     """Predict on a single minibatch of examples."""
     texts: Sequence[str] = [
-        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
     ]
     preprocessed_texts = self.encode_inputs(texts)
     sequence_ids = preprocessed_texts["token_ids"]
     padding_mask = preprocessed_texts["padding_mask"]

-    target_masks = [ex.get(FieldNames.TARGET_MASK, []) for ex in inputs]
+    target_masks = [
+        ex.get(pd_constants.FieldNames.TARGET_MASK, []) for ex in inputs
+    ]

     # Get the predictions.
     batched_outputs = self._pred(sequence_ids, padding_mask, target_masks)
@@ -479,19 +467,10 @@ def predict_minibatch(self, inputs):
     return map(self._postprocess, unbatched_outputs)

   def input_spec(self):
-    return super().input_spec() | {
-        FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
-    }
+    return super().input_spec() | pd_constants.INPUT_SPEC_SALIENCE

   def output_spec(self) -> lit_types.Spec:
-    return {
-        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
-        FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
-        FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(
-            align=FieldNames.TOKENS
-        ),
-        # FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
-    }
+    return pd_constants.OUTPUT_SPEC_SALIENCE


 class KerasTokenizerModel(_KerasBaseModel):
@@ -507,13 +486,15 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("token_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
+    preds[pd_constants.FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     return preds

   def predict_minibatch(self, inputs):
     """Tokenize a single minibatch of examples."""
     texts: Sequence[str] = [
-        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
+        ex[pd_constants.FieldNames.PROMPT]
+        + ex.get(pd_constants.FieldNames.TARGET, "")
+        for ex in inputs
     ]
     preprocessed_texts = self.encode_inputs(texts)
     batched_outputs = {
@@ -529,9 +510,7 @@ def output_spec(self) -> lit_types.Spec:
     return map(self._postprocess, unbatched_outputs)

   def output_spec(self) -> lit_types.Spec:
-    return {
-        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
-    }
+    return pd_constants.OUTPUT_SPEC_TOKENIZER


 def initialize_model_group_for_salience(
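
The point of the refactor is that any parallel backend can declare the same interface by referencing pd_constants rather than copying field names. A minimal sketch, not part of this commit, of what such a wrapper needs; EchoGenerationModel and its echo behavior are hypothetical:

# Sketch only: a hypothetical wrapper reusing the shared constants, mirroring
# the pattern KerasGenerationModel follows after this commit.
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.prompt_debugging import constants as pd_constants


class EchoGenerationModel(lit_model.BatchedModel):
  """Toy model that 'generates' a response by echoing the prompt."""

  def input_spec(self) -> lit_types.Spec:
    return pd_constants.INPUT_SPEC

  def output_spec(self) -> lit_types.Spec:
    return pd_constants.OUTPUT_SPEC_GENERATION

  def predict_minibatch(self, inputs):
    # Same field-access pattern as the refactored wrappers above.
    return [
        {pd_constants.FieldNames.RESPONSE: ex[pd_constants.FieldNames.PROMPT]}
        for ex in inputs
    ]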
(Diff for the third changed file, 38 additions and 40 deletions, did not load on this page.)
