From e7115ab6758d54bb8afe79ffd2cd6c6e5786a835 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 27 Sep 2023 06:28:51 -0700 Subject: [PATCH 01/51] Internal change. PiperOrigin-RevId: 568832193 --- website/sphinx_src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/sphinx_src/index.md b/website/sphinx_src/index.md index 60d3b441..78a63a98 100644 --- a/website/sphinx_src/index.md +++ b/website/sphinx_src/index.md @@ -1,6 +1,6 @@ # Learning Interpretability Tool (LIT) - + From 8a3f366816833ead164ecfca778b465ef6d074bb Mon Sep 17 00:00:00 2001 From: Bin Du Date: Wed, 4 Oct 2023 11:49:16 -0700 Subject: [PATCH 02/51] Make multi label prediction scores visible in the data table column. PiperOrigin-RevId: 570758611 --- lit_nlp/client/services/data_service.ts | 37 +++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/lit_nlp/client/services/data_service.ts b/lit_nlp/client/services/data_service.ts index 2fff8abc..8b59100c 100644 --- a/lit_nlp/client/services/data_service.ts +++ b/lit_nlp/client/services/data_service.ts @@ -19,7 +19,7 @@ import {action, computed, observable, reaction} from 'mobx'; import {BINARY_NEG_POS, ColorRange} from '../lib/colors'; -import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar} from '../lib/lit_types'; +import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar, SparseMultilabelPreds} from '../lib/lit_types'; import {ClassificationResults, IndexedInput, RegressionResults} from '../lib/types'; import {createLitType, findSpecKeys, isLitSubtype, mapsContainSame} from '../lib/utils'; @@ -68,6 +68,8 @@ export const GEN_TEXT_CANDS_SOURCE_PREFIX = 'GeneratedTextCandidates'; export const REGRESSION_SOURCE_PREFIX = 'Regression'; /** Column source prefix for columns from scalar model outputs. */ export const SCALAR_SOURCE_PREFIX = 'Scalar'; +/** Column source prefix for columns from multilabel model outputs. */ +export const MULTILABEL_SOURCE_PREFIX = 'Multilabel'; /** * Data service singleton, responsible for maintaining columns of computed data @@ -109,7 +111,7 @@ export class DataService extends LitService { } }, {fireImmediately: true}); - // Run other preiction interpreters when necessary. + // Run other prediction interpreters when necessary. const getPredictionInputs = () => [this.appState.currentInputData, this.appState.currentModels]; reaction(getPredictionInputs, () => { @@ -124,6 +126,7 @@ export class DataService extends LitService { this.runGeneratedTextPreds(model, this.appState.currentInputData); this.runRegression(model, this.appState.currentInputData); this.runScalarPreds(model, this.appState.currentInputData); + this.runMultiLabelPreds(model, this.appState.currentInputData); } }, {fireImmediately: true}); @@ -301,6 +304,36 @@ export class DataService extends LitService { } } + /** + * Run multi label predictions and store results in data service. + */ + private async runMultiLabelPreds(model: string, data: IndexedInput[]) { + const {output} = this.appState.getModelSpec(model); + if (findSpecKeys(output, SparseMultilabelPreds).length === 0) { + return; + } + + const multiLabelPredsPromise = this.apiService.getPreds( + data, model, this.appState.currentDataset, [SparseMultilabelPreds]); + const preds = await multiLabelPredsPromise; + + // Add multi label prediction results as new column to the data service. 
+ if (preds == null || preds.length === 0) { + return; + } + const multiLabelPredKeys = Object.keys(preds[0]); + for (const key of multiLabelPredKeys) { + const scoreFeatName = this.getColumnName(model, key); + const scores = preds.map(pred => pred[key]); + // TODO(b/303457849): maybe possible to directly use the data type from + // the output spec rather than creating a new one. + const dataType = createLitType(SparseMultilabelPreds); + const source = `${MULTILABEL_SOURCE_PREFIX}:${model}`; + this.addColumnFromList( + scores, data, key, scoreFeatName, dataType, source); + } + } + @action async setValuesForNewDatapoints(datapoints: IndexedInput[]) { // When new datapoints are created, set their data values for each From e63b67484fc7f4dbfa3484126c355350d2127bf7 Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Mon, 16 Oct 2023 11:07:57 -0700 Subject: [PATCH 03/51] Numeric fields with defaults that evaluate to False are considered a missing field. Buttons to add the example are therefore not rendered and an error message is displayed instead. PiperOrigin-RevId: 573868488 --- lit_nlp/client/modules/datapoint_editor_module.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/lit_nlp/client/modules/datapoint_editor_module.ts b/lit_nlp/client/modules/datapoint_editor_module.ts index 7c1b7736..c80f5d66 100644 --- a/lit_nlp/client/modules/datapoint_editor_module.ts +++ b/lit_nlp/client/modules/datapoint_editor_module.ts @@ -99,8 +99,18 @@ export class DatapointEditorModule extends LitModule { @computed get missingFields(): string[] { + // Check only for initial values that are null or undefined for numeric + // values to allow for defaults that evaluate to false. return this.appState.currentModelRequiredInputSpecKeys.filter( - key => !this.editedData.data[key]); + (key: string) => { + if (this.groupService.numericalFeatureNames.includes(key)) { + return this.editedData.data[key] === undefined || + this.editedData.data[key] === null; + } + else { + return !this.editedData.data[key]; + } + }); } @computed From d2745088966c4ac31a3755f55096eeb8193c5a91 Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Fri, 20 Oct 2023 17:00:10 -0700 Subject: [PATCH 04/51] Introduce expanded default range for integer as validation steps in datapoint editor now reduce the range of values previously supported by default PiperOrigin-RevId: 575351072 --- lit_nlp/api/testdata/count_examples.indexed.lit.jsonl.spec | 4 ++-- lit_nlp/api/testdata/count_examples.lit.jsonl.spec | 4 ++-- lit_nlp/api/types.py | 2 ++ lit_nlp/client/lib/lit_types.ts | 2 ++ website/sphinx_src/api.md | 5 ++--- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lit_nlp/api/testdata/count_examples.indexed.lit.jsonl.spec b/lit_nlp/api/testdata/count_examples.indexed.lit.jsonl.spec index 7ac049cf..7ad4ebfb 100644 --- a/lit_nlp/api/testdata/count_examples.indexed.lit.jsonl.spec +++ b/lit_nlp/api/testdata/count_examples.indexed.lit.jsonl.spec @@ -20,8 +20,8 @@ "value": { "required": true, "annotated": false, - "min_val": 0, - "max_val": 1, + "min_val": -32768, + "max_val": 32767, "default": 0, "step": 1, "__class__": "LitType", diff --git a/lit_nlp/api/testdata/count_examples.lit.jsonl.spec b/lit_nlp/api/testdata/count_examples.lit.jsonl.spec index 7ac049cf..7ad4ebfb 100644 --- a/lit_nlp/api/testdata/count_examples.lit.jsonl.spec +++ b/lit_nlp/api/testdata/count_examples.lit.jsonl.spec @@ -20,8 +20,8 @@ "value": { "required": true, "annotated": false, - "min_val": 0, - "max_val": 1, + "min_val": -32768, + "max_val": 
32767, "default": 0, "step": 1, "__class__": "LitType", diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py index f3ea60c5..153a9110 100644 --- a/lit_nlp/api/types.py +++ b/lit_nlp/api/types.py @@ -917,6 +917,8 @@ class MetricResult(LitType): @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Integer(Scalar): step: int = 1 + min_val: int = -32768 + max_val: int = 32767 @attr.s(auto_attribs=True, frozen=True, kw_only=True) diff --git a/lit_nlp/client/lib/lit_types.ts b/lit_nlp/client/lib/lit_types.ts index c334301e..d966dd75 100644 --- a/lit_nlp/client/lib/lit_types.ts +++ b/lit_nlp/client/lib/lit_types.ts @@ -232,6 +232,8 @@ export class Scalar extends LitType { */ @registered export class Integer extends Scalar { + override min_val = -32768; + override max_val = 32767; override step = 1; } diff --git a/website/sphinx_src/api.md b/website/sphinx_src/api.md index 977837ea..a174bf39 100644 --- a/website/sphinx_src/api.md +++ b/website/sphinx_src/api.md @@ -807,8 +807,7 @@ _See the [examples](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples) ### Available types -The full set of `LitType`s is defined in -[types.py](https://github.com/PAIR-code/lit/blob/main/lit_nlp/api/types.py), and summarized +The full set of `LitType`s is defined in [types.py](https://github.com/PAIR-code/lit/blob/main/lit_nlp/api/types.py). Numeric types such as `Integer` and `Scalar` have predefined ranges that can be overridden using corresponding `min_val` and `max_val` attributes as seen [here](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/datasets/penguin_data.py;l=19-22;rcl=574999438). The different types available in LIT are summarized in the table below. Note: Bracket syntax, such as `[num_tokens]`, refers to the shapes of @@ -828,7 +827,7 @@ Name | Description `TokenTopKPreds` | Predicted tokens and their scores, as from a language model or seq2seq model. | `list[list[tuple[str, float]]]` `Boolean` | Boolean value. | `bool` `Scalar` | Scalar numeric value. | `float` -`Integer` | Integer value. | `int` +`Integer` | Integer, with a default range from -32768 to +32767. value. | `int` `ImageBytes` | Image, represented by a base64 encoded string. LIT also provides `JPEGBytes` and `PNGBytes` types for those specific encodings. | `str` `RegressionScore` | Scalar value, treated as a regression target or prediction. | `float` `ReferenceScores` | Scores for one or more reference texts. | `list[float]` From a21986342d83ae64d58607e337fab9db7736242a Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Mon, 23 Oct 2023 15:22:30 -0700 Subject: [PATCH 05/51] Override layouts instead of merging to allow opt-out from default set of options PiperOrigin-RevId: 575944519 --- lit_nlp/app.py | 6 +----- lit_nlp/examples/coref/coref_demo.py | 2 +- lit_nlp/examples/custom_module/potato_demo.py | 8 ++++---- lit_nlp/examples/dalle/demo.py | 5 +++-- lit_nlp/examples/image_demo.py | 2 ++ lit_nlp/examples/is_eval/is_eval_demo.py | 2 +- lit_nlp/examples/lm_demo.py | 3 ++- lit_nlp/examples/penguin_demo.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lit_nlp/app.py b/lit_nlp/app.py index 2ce66896..a83eabb7 100644 --- a/lit_nlp/app.py +++ b/lit_nlp/app.py @@ -899,11 +899,7 @@ def __init__( id_hash_fn=caching.input_hash, ) - # TODO(lit-dev): override layouts instead of merging, to allow clients - # to opt-out of the default bundled layouts. This will require updating - # client code to manually merge when this is the desired behavior. 
- self._layouts = dict(layout.DEFAULT_LAYOUTS, **(layouts or {})) - + self._layouts = layouts if layouts else layout.DEFAULT_LAYOUTS self._model_loaders: ModelLoadersMap = model_loaders or {} self._models: dict[str, caching.CachingModelWrapper] = {} for name, model in models.items(): diff --git a/lit_nlp/examples/coref/coref_demo.py b/lit_nlp/examples/coref/coref_demo.py index cb395b17..d960909c 100644 --- a/lit_nlp/examples/coref/coref_demo.py +++ b/lit_nlp/examples/coref/coref_demo.py @@ -123,7 +123,7 @@ }, description="Custom layout for the Winogender coreference demo.", ) -CUSTOM_LAYOUTS = {"winogender": WINOGENDER_LAYOUT} +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"winogender": WINOGENDER_LAYOUT} FLAGS.set_default("default_layout", "winogender") diff --git a/lit_nlp/examples/custom_module/potato_demo.py b/lit_nlp/examples/custom_module/potato_demo.py index 25c6b17a..098fd2e6 100644 --- a/lit_nlp/examples/custom_module/potato_demo.py +++ b/lit_nlp/examples/custom_module/potato_demo.py @@ -57,6 +57,8 @@ description="Custom layout with our spud-tastic potato module.", ) +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"potato": POTATO_LAYOUT} + def get_wsgi_app() -> Optional[dev_server.LitServerType]: """Returns a LitApp instance for consumption by gunicorn.""" @@ -86,10 +88,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: # Start the LIT server. See server_flags.py for server options. lit_demo = dev_server.Server( - models, - datasets, - layouts={"potato": POTATO_LAYOUT}, - **server_flags.get_flags()) + models, datasets, layouts=CUSTOM_LAYOUTS, **server_flags.get_flags() + ) return lit_demo.serve() diff --git a/lit_nlp/examples/dalle/demo.py b/lit_nlp/examples/dalle/demo.py index 1a466d4b..ea049178 100644 --- a/lit_nlp/examples/dalle/demo.py +++ b/lit_nlp/examples/dalle/demo.py @@ -58,7 +58,8 @@ }, description="Custom layout for Text to Image models.", ) -_CUSTOM_LAYOUTS = {"DALLE_LAYOUT": _DALLE_LAYOUT} + +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"DALLE_LAYOUT": _DALLE_LAYOUT} def get_wsgi_app() -> Optional[dev_server.LitServerType]: @@ -92,7 +93,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: lit_demo = dev_server.Server( models, datasets, - layouts=_CUSTOM_LAYOUTS, + layouts=CUSTOM_LAYOUTS, **server_flags.get_flags(), ) return lit_demo.serve() diff --git a/lit_nlp/examples/image_demo.py b/lit_nlp/examples/image_demo.py index 459ef6d6..5d9d25a4 100644 --- a/lit_nlp/examples/image_demo.py +++ b/lit_nlp/examples/image_demo.py @@ -66,6 +66,8 @@ def get_wsgi_app(): description='Basic layout for image demo', ) +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {'default': DEMO_LAYOUT} + def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: if len(argv) > 1: diff --git a/lit_nlp/examples/is_eval/is_eval_demo.py b/lit_nlp/examples/is_eval/is_eval_demo.py index b75abe65..0621adef 100644 --- a/lit_nlp/examples/is_eval/is_eval_demo.py +++ b/lit_nlp/examples/is_eval/is_eval_demo.py @@ -91,7 +91,7 @@ ], }, description="Custom layout for evaluating input salience methods.") -CUSTOM_LAYOUTS = {"is_eval": IS_EVAL_LAYOUT} +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"is_eval": IS_EVAL_LAYOUT} # You can change this back via URL param, e.g. 
localhost:5432/?layout=default FLAGS.set_default("default_layout", "is_eval") diff --git a/lit_nlp/examples/lm_demo.py b/lit_nlp/examples/lm_demo.py index b1cef5ab..e6b3d422 100644 --- a/lit_nlp/examples/lm_demo.py +++ b/lit_nlp/examples/lm_demo.py @@ -86,7 +86,8 @@ }, description="Custom layout for language models.", ) -CUSTOM_LAYOUTS = {"lm": LM_LAYOUT} + +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"lm": LM_LAYOUT} # You can also change this via URL param e.g. localhost:5432/?layout=default FLAGS.set_default("default_layout", "lm") diff --git a/lit_nlp/examples/penguin_demo.py b/lit_nlp/examples/penguin_demo.py index 109f9ea9..79cc077c 100644 --- a/lit_nlp/examples/penguin_demo.py +++ b/lit_nlp/examples/penguin_demo.py @@ -49,7 +49,7 @@ lower=layout.STANDARD_LAYOUT.lower, description='Custom layout for the Palmer Penguins demo.', ) -CUSTOM_LAYOUTS = {'penguins': PENGUIN_LAYOUT} +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {'penguins': PENGUIN_LAYOUT} # Function for running demo through gunicorn instead of the local dev server. From 688a8cdad9c4f7c9e57fd3eab4e1d6cbbdff2fac Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Fri, 27 Oct 2023 14:05:52 -0700 Subject: [PATCH 06/51] Add dependencies for scatter_gl PiperOrigin-RevId: 577299901 --- lit_nlp/examples/custom_module/potato_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/examples/custom_module/potato_demo.py b/lit_nlp/examples/custom_module/potato_demo.py index 098fd2e6..89eef2e7 100644 --- a/lit_nlp/examples/custom_module/potato_demo.py +++ b/lit_nlp/examples/custom_module/potato_demo.py @@ -4,7 +4,7 @@ It also uses a custom frontend build, which has a fun potato module! To run locally: - python -m lit_nlp.examples.potato_demo --port=5432 + blaze run -c opt --config=cuda examples/custom_module:potato_demo -- --port=5432 Once you see the ASCII-art LIT logo, navigate to localhost:5432 to access the demo UI. From f254fa8500d6267278fa3dc32fb4bbf56beb7cf7 Mon Sep 17 00:00:00 2001 From: Bin Du Date: Mon, 6 Nov 2023 13:39:38 -0800 Subject: [PATCH 07/51] Bump jaxlib version to 0.4.13 and flax version to 0.6.11. PiperOrigin-RevId: 579944240 --- lit_nlp/examples/dalle/requirements.txt | 6 +++--- requirements_examples.txt | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lit_nlp/examples/dalle/requirements.txt b/lit_nlp/examples/dalle/requirements.txt index 3eeff39a..d431c2dd 100644 --- a/lit_nlp/examples/dalle/requirements.txt +++ b/lit_nlp/examples/dalle/requirements.txt @@ -15,9 +15,9 @@ -r ../../../requirements_core.txt -jax==0.3.25 -jaxlib==0.3.25 -flax==0.6.3 +jax==0.4.13 +jaxlib==0.4.13 +flax==0.6.11 dalle-mini==0.1.5 ipywidgets==7.5.0 orbax==0.0.23 diff --git a/requirements_examples.txt b/requirements_examples.txt index 87d542b2..9362ad5d 100644 --- a/requirements_examples.txt +++ b/requirements_examples.txt @@ -18,7 +18,7 @@ sentencepiece==0.1.99 tensorflow-datasets==4.8.0 torch==2.0.1 transformers==4.27.1 -jax==0.3.16 -jaxlib==0.3.15 -flax==0.5.3 +jax==0.4.13 +jaxlib==0.4.13 +flax==0.6.11 # LINT.ThenChange(./pyproject.toml) From d4302bd6bfc7e4c778ba0e96397ac620242a8d21 Mon Sep 17 00:00:00 2001 From: Bin Du Date: Mon, 6 Nov 2023 13:57:50 -0800 Subject: [PATCH 08/51] Annotate new data points created from counterfactual generator before adding to the data table. 
PiperOrigin-RevId: 579949562 --- lit_nlp/client/modules/generator_module.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/client/modules/generator_module.ts b/lit_nlp/client/modules/generator_module.ts index 87b9bbb9..b5fed37a 100644 --- a/lit_nlp/client/modules/generator_module.ts +++ b/lit_nlp/client/modules/generator_module.ts @@ -205,7 +205,7 @@ export class GeneratorModule extends LitModule { } private async createNewDatapoints(data: IndexedInput[][]) { - const newExamples = flatten(data); + const newExamples = await this.appState.annotateNewData(flatten(data)); this.appState.commitNewDatapoints(newExamples); const newIds = newExamples.map(d => d.id); if (newIds.length === 0) return; From 7bcdb192e5032f40e143872715a5c6a3c353f763 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Tue, 7 Nov 2023 11:28:31 -0800 Subject: [PATCH 09/51] Fix LIT rendering in colab PiperOrigin-RevId: 580246237 --- lit_nlp/notebook.py | 43 +++++-------------------------------------- 1 file changed, 5 insertions(+), 38 deletions(-) diff --git a/lit_nlp/notebook.py b/lit_nlp/notebook.py index 57610e70..87a051cd 100644 --- a/lit_nlp/notebook.py +++ b/lit_nlp/notebook.py @@ -22,15 +22,16 @@ from lit_nlp.api import layout from lit_nlp.lib import wsgi_serving +is_colab = False try: import google.colab # pylint: disable=g-import-not-at-top,unused-import + from google.colab import output # pylint: disable=g-import-not-at-top,unused-import # pytype: disable=import-error is_colab = True # Can disable import error as this package is always # included in colab kernels. from colabtools import interactive_widgets # pylint: disable=g-import-not-at-top # pytype: disable=import-error progress_indicator = interactive_widgets.ProgressIter except (ImportError, ModuleNotFoundError): - is_colab = False from tqdm import notebook # pylint: disable=g-import-not-at-top progress_indicator = notebook.tqdm @@ -167,46 +168,12 @@ def _display_colab(port, height, open_in_new_tab, ui_params: RenderConfig): """ params = ui_params.get_query_str() + path = f'/{params}' if open_in_new_tab: - shell = """ - (async () => { - const proxyPort = await google.colab.kernel.proxyPort( - %PORT%, {'cache': true}) - const url = new URL(proxyPort + '%PARAMS%') - const a = document.createElement('a'); - a.href = "javascript:void(0);" - a.onclick = (e) => window.open(url, "_blank"); - a.innerHTML = url; - document.body.appendChild(a); - window.open(url, "_blank"); - })(); - """ + output.serve_kernel_port_as_window(port, path=path) else: - shell = """ - (async () => { - const proxyPort = await google.colab.kernel.proxyPort( - %PORT%, {'cache': true}) - const url = new URL(proxyPort + '%PARAMS%') - const iframe = document.createElement('iframe'); - iframe.src = url; - iframe.setAttribute('width', '100%'); - iframe.setAttribute('height', '%HEIGHT%px'); - iframe.setAttribute('frameborder', 0); - document.body.appendChild(iframe); - })(); - """ - - replacements = [ - ('%PORT%', '%d' % port), - ('%HEIGHT%', '%d' % height), - ('%PARAMS%', '%s' % params), - ] - for (k, v) in replacements: - shell = shell.replace(k, v) - - script = display.Javascript(shell) - display.display(script) + output.serve_kernel_port_as_iframe(port, height=f'{height}', path=path) def _display_jupyter( From 0dcb31df8539aec2eeefa097322ce86cc2feb41e Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Tue, 21 Nov 2023 12:50:26 -0800 Subject: [PATCH 10/51] Internal Change PiperOrigin-RevId: 584404376 --- lit_nlp/api/dtypes.py | 8 ++++++++ lit_nlp/api/types.py | 7 +++++++ 
lit_nlp/client/lib/dtypes.ts | 5 +++++ lit_nlp/client/lib/lit_types.ts | 10 +++++++++- lit_nlp/client/modules/salience_map_module.ts | 9 ++++----- 5 files changed, 33 insertions(+), 6 deletions(-) diff --git a/lit_nlp/api/dtypes.py b/lit_nlp/api/dtypes.py index dbb6df82..efd8504a 100644 --- a/lit_nlp/api/dtypes.py +++ b/lit_nlp/api/dtypes.py @@ -112,6 +112,14 @@ class FeatureSalience(DataTuple): salience: dict[str, float] +@attr.s(auto_attribs=True, frozen=True, slots=True) +class FrameSalience(DataTuple): + """Dataclass for a salience map over image frames in a video.""" + + # A map of salience score and image string by frame number + salience: dict[str, tuple[float, str]] + + # TODO(b/196886684): document API for salience interpreters. @attr.s(auto_attribs=True, frozen=True, slots=True) class SequenceSalienceMap(DataTuple): diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py index 153a9110..2ba9d274 100644 --- a/lit_nlp/api/types.py +++ b/lit_nlp/api/types.py @@ -849,6 +849,13 @@ class FeatureSalience(Salience): default: dtypes.FeatureSalience = None +@attr.s(auto_attribs=True, frozen=True, kw_only=True) +class FrameSalience(Salience): + """Metadata about a returned frame salience map.""" + + default: dtypes.FrameSalience = None + + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ImageSalience(Salience): """Metadata about a returned image saliency. diff --git a/lit_nlp/client/lib/dtypes.ts b/lit_nlp/client/lib/dtypes.ts index 9ab6acd9..bc8f7a6a 100644 --- a/lit_nlp/client/lib/dtypes.ts +++ b/lit_nlp/client/lib/dtypes.ts @@ -60,6 +60,11 @@ export interface FeatureSalience extends DataTuple { salience: {[key: string]: number}; } +/** Dataclass for a salience map over image frames in a video. */ +export interface FrameSalience extends DataTuple { + salience: {[key: string]: [number, string]}; +} + // TODO(b/196886684): document API for salience interpreters. /** Dataclass for a salience map over a target sequence. */ export interface SequenceSalienceMap extends DataTuple { diff --git a/lit_nlp/client/lib/lit_types.ts b/lit_nlp/client/lib/lit_types.ts index d966dd75..3a1aa570 100644 --- a/lit_nlp/client/lib/lit_types.ts +++ b/lit_nlp/client/lib/lit_types.ts @@ -20,7 +20,7 @@ // their Python counterparts. // tslint:disable:no-new-decorators class-name enforce-name-casing -import {AnnotationCluster, EdgeLabel, FeatureSalience as FeatureSalienceDType, ScoredTextCandidates, SequenceSalienceMap, SpanLabel, TokenSalience as TokenSalienceDType} from './dtypes'; +import {AnnotationCluster, EdgeLabel, FeatureSalience as FeatureSalienceDType, ScoredTextCandidates, SequenceSalienceMap, SpanLabel, TokenSalience as TokenSalienceDType, FrameSalience as FrameSalienceDType} from './dtypes'; /** * A dictionary of registered LitType names mapped to their constructor. @@ -544,6 +544,14 @@ export class FeatureSalience extends Salience { override default: FeatureSalienceDType|undefined = undefined; } +/** + * Metadata about a returned frame salience map. + */ +@registered +export class FrameSalience extends Salience { + override default: FrameSalienceDType|undefined = undefined; +} + /** * Metadata about a returned image saliency. 
* The data is returned as an image in the base64 URL encoded format, e.g., diff --git a/lit_nlp/client/modules/salience_map_module.ts b/lit_nlp/client/modules/salience_map_module.ts index 68992c98..a36e6430 100644 --- a/lit_nlp/client/modules/salience_map_module.ts +++ b/lit_nlp/client/modules/salience_map_module.ts @@ -35,7 +35,7 @@ import {LitModule} from '../core/lit_module'; import {LegendType} from '../elements/color_legend'; import {InterpreterClick} from '../elements/interpreter_controls'; import {TokenWithWeight} from '../elements/token_chips'; -import {FeatureSalience, FieldMatcher, ImageGradients, ImageSalience, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, RegressionScore, Salience, SalienceTargetInfo, TokenGradients, TokenSalience} from '../lib/lit_types'; +import {FeatureSalience, FieldMatcher, ImageGradients, ImageSalience, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, RegressionScore, Salience, SalienceTargetInfo, TokenGradients, TokenSalience, FrameSalience} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {CallConfig, IndexedInput, ModelInfoMap, Preds, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {argmax, cloneSpec, findSpecKeys, makeModifiedInput} from '../lib/utils'; @@ -64,12 +64,11 @@ interface FeatureSalienceResult { [key: string]: {salience: FeatureSalienceMap}; } -type SalienceResult = TokenSalienceResult | ImageSalienceResult | +/** Different types of salience results supported in this module. */ +export type SalienceResult = TokenSalienceResult | ImageSalienceResult | FeatureSalienceResult; -// Notably, not SequenceSalience as that is handled by a different module. -const SUPPORTED_SALIENCE_TYPES = - [TokenSalience, FeatureSalience, ImageSalience]; +const SUPPORTED_SALIENCE_TYPES = [TokenSalience, FeatureSalience, ImageSalience, FrameSalience]; const TARGET_SELECTOR_SUPPORTED_TYPES: LitTypeTypesList = [MulticlassPreds, RegressionScore]; From ac8ed5902a2c96019ea1137b5138d48017fabf4e Mon Sep 17 00:00:00 2001 From: Bin Du Date: Tue, 28 Nov 2023 11:05:49 -0800 Subject: [PATCH 11/51] `SimpleSentimentModel` should inherit from `BatchedModel`. This class implements the abstract method `predict_minibatch` in `BatchedModel`. Addressing https://github.com/PAIR-code/lit/issues/1361. PiperOrigin-RevId: 586041904 --- lit_nlp/examples/simple_pytorch_demo.py | 4 ++-- lit_nlp/examples/sst_pytorch_demo.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lit_nlp/examples/simple_pytorch_demo.py b/lit_nlp/examples/simple_pytorch_demo.py index 2a925895..cde642db 100644 --- a/lit_nlp/examples/simple_pytorch_demo.py +++ b/lit_nlp/examples/simple_pytorch_demo.py @@ -79,7 +79,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -103,7 +103,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. 
return 32 diff --git a/lit_nlp/examples/sst_pytorch_demo.py b/lit_nlp/examples/sst_pytorch_demo.py index fd341b37..dede8a61 100644 --- a/lit_nlp/examples/sst_pytorch_demo.py +++ b/lit_nlp/examples/sst_pytorch_demo.py @@ -70,7 +70,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -95,7 +95,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. return 32 From 724bdee1f9ea45ce998b9031eea4ad1169299efb Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 6 Dec 2023 11:25:13 -0800 Subject: [PATCH 12/51] Move TyDi QA demo to its own directory. PiperOrigin-RevId: 588490476 --- .gitignore | 1 + lit_nlp/examples/dalle/README.md | 4 +-- lit_nlp/examples/dalle/demo.py | 3 +- lit_nlp/examples/models/tydi_test.py | 30 ------------------- lit_nlp/examples/tydi/README.md | 26 ++++++++++++++++ .../examples/{tydi_demo.py => tydi/demo.py} | 7 ++--- .../{models/tydi.py => tydi/model.py} | 0 lit_nlp/examples/tydi/requirements.txt | 20 +++++++++++++ requirements_examples.txt | 3 -- 9 files changed, 53 insertions(+), 41 deletions(-) delete mode 100644 lit_nlp/examples/models/tydi_test.py create mode 100644 lit_nlp/examples/tydi/README.md rename lit_nlp/examples/{tydi_demo.py => tydi/demo.py} (92%) rename lit_nlp/examples/{models/tydi.py => tydi/model.py} (100%) create mode 100644 lit_nlp/examples/tydi/requirements.txt diff --git a/.gitignore b/.gitignore index 0ec4e4f6..1388772e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,5 +15,6 @@ docs/documentation/.doctrees/** **/.DS_Store .dalle-venv/ +.tydi-venv/ .venv/ .vscode/ diff --git a/lit_nlp/examples/dalle/README.md b/lit_nlp/examples/dalle/README.md index 71cd296c..52a53c54 100644 --- a/lit_nlp/examples/dalle/README.md +++ b/lit_nlp/examples/dalle/README.md @@ -21,8 +21,8 @@ LIT repo. python -m venv .dalle-venv source .dalle-venv/bin/activate # This requirements.txt file will also install the core LIT library deps. -pip install -r ./lit_nlp/examples/dalle-mini/requirements.txt -# The LIT web app can still needs be built in the usual way. +pip install -r ./lit_nlp/examples/dalle/requirements.txt +# The LIT web app still needs to be built in the usual way. (cd ./lit_nlp && yarn && yarn build) ``` diff --git a/lit_nlp/examples/dalle/demo.py b/lit_nlp/examples/dalle/demo.py index ea049178..a96f4abd 100644 --- a/lit_nlp/examples/dalle/demo.py +++ b/lit_nlp/examples/dalle/demo.py @@ -1,8 +1,7 @@ r"""Example for dalle demo model. To run locally with a small number of examples: - python -m lit_nlp.examples.dalle_demo \ - --alsologtostderr --port=5432 + python -m lit_nlp.examples.dalle.demo Then navigate to localhost:5432 to access the demo UI. 
""" diff --git a/lit_nlp/examples/models/tydi_test.py b/lit_nlp/examples/models/tydi_test.py deleted file mode 100644 index df37691a..00000000 --- a/lit_nlp/examples/models/tydi_test.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Tests for lit_nlp.examples.models.tydi.""" - -from absl.testing import absltest -from lit_nlp.api import types as lit_types -from lit_nlp.examples.models import tydi - - -class TyDiModelTest(absltest.TestCase): - """Test that model classes conform to the expected spec.""" - - def test_model_specs(self): - model = tydi.TyDiModel("TyDiModel", model="dummy", tokenizer="dummy") - # Check inputs - ispec = model.input_spec() - self.assertIn("context", ispec) - self.assertIsInstance(ispec["context"], lit_types.TextSegment) - self.assertIn("answers_text", ispec) - self.assertIsInstance( - ispec["answers_text"], lit_types.MultiSegmentAnnotations - ) - - # Check outputs - ospec = model.output_spec() - self.assertIn("generated_text", ospec) - self.assertIsInstance(ospec["generated_text"], lit_types.GeneratedText) - self.assertEqual(ospec["generated_text"].parent, "answers_text") - - -if __name__ == "__main__": - absltest.main() diff --git a/lit_nlp/examples/tydi/README.md b/lit_nlp/examples/tydi/README.md new file mode 100644 index 00000000..f8b3508d --- /dev/null +++ b/lit_nlp/examples/tydi/README.md @@ -0,0 +1,26 @@ +TyDi QA Demo for the Learning Interpretability Tool +======================================================= + +This demo showcases how LIT can be used to a multilingual question-answering +model trained on the [TyDi QA dataset](https://doi.org/10.1162/tacl_a_00317) +using FLAX. + +You will need a stand-alone virtual environment for the Python libraries, which you can set up using the following commands from the root of the LIT repo. + +```sh +# Create the virtual environment. You may want to use python3 or python3.10 +# depends on how many Python versions you have installed and their aliases. +python -m venv .tydi-venv +source .tydi-venv/bin/activate +# This requirements.txt file will also install the core LIT library deps. +pip install -r ./lit_nlp/examples/tydi/requirements.txt +# The LIT web app still needs to be built in the usual way. +(cd ./lit_nlp && yarn && yarn build) +``` + +Once your virtual environment is setup, you can launch the demo with the +following command. + +```sh +python -m lit_nlp.examples.tydi.demo +``` diff --git a/lit_nlp/examples/tydi_demo.py b/lit_nlp/examples/tydi/demo.py similarity index 92% rename from lit_nlp/examples/tydi_demo.py rename to lit_nlp/examples/tydi/demo.py index 2c043ab0..02739c5b 100644 --- a/lit_nlp/examples/tydi_demo.py +++ b/lit_nlp/examples/tydi/demo.py @@ -1,8 +1,7 @@ r"""Example demo loading a TyDiModel. To run locally with a small number of examples: - python -m lit_nlp.examples.tydi_demo \ - --alsologtostderr --port=5432 --max_examples=10 + python -m lit_nlp.examples.tydi.demo Then navigate to localhost:5432 to access the demo UI. """ @@ -18,7 +17,7 @@ from lit_nlp import server_flags from lit_nlp.components import word_replacer from lit_nlp.examples.datasets import question_answering -from lit_nlp.examples.models import tydi +from lit_nlp.examples.tydi import model # NOTE: additional flags defined in server_flags.py _FLAGS = flags.FLAGS @@ -56,7 +55,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: # Ignore path prefix, if using /path/to/ to load from a # specific directory rather than the default shortcut. 
model_name = os.path.basename(model_name_or_path) - models[model_name] = tydi.TyDiModel(model_name=model_name_or_path) + models[model_name] = model.TyDiModel(model_name=model_name_or_path) max_examples: int = _MAX_EXAMPLES.value dataset_defs: tuple[tuple[str, str]] = ( diff --git a/lit_nlp/examples/models/tydi.py b/lit_nlp/examples/tydi/model.py similarity index 100% rename from lit_nlp/examples/models/tydi.py rename to lit_nlp/examples/tydi/model.py diff --git a/lit_nlp/examples/tydi/requirements.txt b/lit_nlp/examples/tydi/requirements.txt new file mode 100644 index 00000000..184b9971 --- /dev/null +++ b/lit_nlp/examples/tydi/requirements.txt @@ -0,0 +1,20 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +-r ../../../requirements_core.txt + +jax==0.3.16 +jaxlib==0.3.15 +flax==0.5.3 diff --git a/requirements_examples.txt b/requirements_examples.txt index 9362ad5d..d3353493 100644 --- a/requirements_examples.txt +++ b/requirements_examples.txt @@ -18,7 +18,4 @@ sentencepiece==0.1.99 tensorflow-datasets==4.8.0 torch==2.0.1 transformers==4.27.1 -jax==0.4.13 -jaxlib==0.4.13 -flax==0.6.11 # LINT.ThenChange(./pyproject.toml) From a381cb97f04920424847af7bef3fb3143a3e2c03 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 11 Dec 2023 09:57:25 -0800 Subject: [PATCH 13/51] Fix multiprocessing import. 
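The `ThreadPool` class is defined in the `multiprocessing.pool` submodule, which a bare `import multiprocessing` does not load, so references to `multiprocessing.pool.ThreadPool` can fail with an AttributeError. A minimal, self-contained sketch of the pattern the corrected import supports (illustrative example only, not LIT code; the `square` helper is hypothetical):

    import multiprocessing.pool  # explicitly load the submodule that defines ThreadPool

    def square(x: int) -> int:
      return x * x

    # ThreadPool maps the function across worker threads; the submodule import
    # above is what makes `multiprocessing.pool.ThreadPool` resolvable.
    with multiprocessing.pool.ThreadPool(processes=4) as pool:
      print(pool.map(square, range(8)))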
PiperOrigin-RevId: 589856836 --- lit_nlp/api/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py index 6acc4b3c..e8b29363 100644 --- a/lit_nlp/api/model.py +++ b/lit_nlp/api/model.py @@ -17,7 +17,7 @@ from collections.abc import Iterable, Iterator import inspect import itertools -import multiprocessing # for ThreadPool +import multiprocessing.pool # for ThreadPool from typing import Optional, Union from absl import logging From 1d511ba728083a7692c26012e0c0dca4c7a01c60 Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Tue, 2 Jan 2024 15:20:08 -0800 Subject: [PATCH 14/51] Internal Change PiperOrigin-RevId: 595222708 --- lit_nlp/api/dtypes.py | 4 +- lit_nlp/client/elements/frames_window.css | 13 ++++++ lit_nlp/client/elements/frames_window.ts | 52 +++++++++++++++++++++++ lit_nlp/client/elements/table_types.ts | 2 +- lit_nlp/client/lib/dtypes.ts | 2 +- 5 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 lit_nlp/client/elements/frames_window.css create mode 100644 lit_nlp/client/elements/frames_window.ts diff --git a/lit_nlp/api/dtypes.py b/lit_nlp/api/dtypes.py index efd8504a..2e036763 100644 --- a/lit_nlp/api/dtypes.py +++ b/lit_nlp/api/dtypes.py @@ -116,8 +116,8 @@ class FeatureSalience(DataTuple): class FrameSalience(DataTuple): """Dataclass for a salience map over image frames in a video.""" - # A map of salience score and image string by frame number - salience: dict[str, tuple[float, str]] + # A map of salience score and image bytes string by frame number + salience: dict[str, tuple[float, Sequence[str]]] # TODO(b/196886684): document API for salience interpreters. diff --git a/lit_nlp/client/elements/frames_window.css b/lit_nlp/client/elements/frames_window.css new file mode 100644 index 00000000..1eebaef1 --- /dev/null +++ b/lit_nlp/client/elements/frames_window.css @@ -0,0 +1,13 @@ +.frames-window-image { + flex: 16.5%; + padding: 5px; +} + +#frame { + max-width: 202px; + max-height: 360px; +} + +.frames-window { + display: flex; +} diff --git a/lit_nlp/client/elements/frames_window.ts b/lit_nlp/client/elements/frames_window.ts new file mode 100644 index 00000000..e7b169c8 --- /dev/null +++ b/lit_nlp/client/elements/frames_window.ts @@ -0,0 +1,52 @@ +/** + * LIT module for displaying a variable size window of image frames. + */ + + + +import {html, LitElement} from 'lit'; +import {customElement, property} from 'lit/decorators.js'; +import {styles as sharedStyles} from '../lib/shared_styles.css'; + +import {styles} from './frames_window.css'; + +/** + * A LIT module to display variable size list of image frames within a window. + */ +@customElement('lit-frames-window') +export class FramesWindow extends LitElement { + @property({type: Array}) frames: string[] = []; + + static override get styles() { + return [ + sharedStyles, + styles, + ]; + } + + + private renderImage(imgSrc: string) { + return html` +
+        <div class="frames-window-image">
+          <img id="frame" src=${imgSrc}>
+        </div>
`; + } + + + override render() { + const framesDOM = + this.frames.map((imageSrc: string) => this.renderImage(imageSrc)); + + return html` +
+        <div class="frames-window">
+          ${framesDOM}
+        </div>
`; + } +} + + +declare global { + interface HTMLElementTagNameMap { + 'lit-frames-window': FramesWindow; + } +} diff --git a/lit_nlp/client/elements/table_types.ts b/lit_nlp/client/elements/table_types.ts index dd6d16ad..11d7e5fc 100644 --- a/lit_nlp/client/elements/table_types.ts +++ b/lit_nlp/client/elements/table_types.ts @@ -36,7 +36,7 @@ export interface SortableTemplateResult { value: SortableTableEntry; } /** Wrapper types for the data supplied to the data table */ -export type TableEntry = string|number|TemplateResult|SortableTemplateResult; +export type TableEntry = string|number|string[]|TemplateResult|SortableTemplateResult; /** Wrapper types for the rows of data supplied to the data table */ export type TableData = TableEntry[]|{[key: string]: TableEntry}; diff --git a/lit_nlp/client/lib/dtypes.ts b/lit_nlp/client/lib/dtypes.ts index bc8f7a6a..4a18a1dd 100644 --- a/lit_nlp/client/lib/dtypes.ts +++ b/lit_nlp/client/lib/dtypes.ts @@ -62,7 +62,7 @@ export interface FeatureSalience extends DataTuple { /** Dataclass for a salience map over image frames in a video. */ export interface FrameSalience extends DataTuple { - salience: {[key: string]: [number, string]}; + salience: {[key: string]: [number, string[]]}; } // TODO(b/196886684): document API for salience interpreters. From fb9ffbad6715510d11cffd69fc1f08b7a5ef6dc6 Mon Sep 17 00:00:00 2001 From: Cibi Arjun Date: Wed, 10 Jan 2024 12:21:45 -0800 Subject: [PATCH 15/51] Internal Change PiperOrigin-RevId: 597322685 --- lit_nlp/client/modules/salience_map_module.ts | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lit_nlp/client/modules/salience_map_module.ts b/lit_nlp/client/modules/salience_map_module.ts index a36e6430..cdd8d473 100644 --- a/lit_nlp/client/modules/salience_map_module.ts +++ b/lit_nlp/client/modules/salience_map_module.ts @@ -60,7 +60,10 @@ interface FeatureSalienceMap { [feature: string]: number; } -interface FeatureSalienceResult { +/** + * Results for calls to fetch salience for features. + */ +export interface FeatureSalienceResult { [key: string]: {salience: FeatureSalienceMap}; } @@ -665,7 +668,7 @@ export class SalienceMapModule extends LitModule { // the label and the expander toggle. // clang-format off return html` -
+
@@ -675,6 +678,8 @@ export class SalienceMapModule extends LitModule {
${this.selectionService.primarySelectedInputData != null ? + // TODO(b/319297222) Modify element such that we render each + // feature in the feature salience module in a separate line. salienceContent : html` Select a datapoint to see ${name} attributions. From 8122fc927bdd93d3cdea619114aebcfbf1f3c469 Mon Sep 17 00:00:00 2001 From: Jan Kuehle Date: Tue, 16 Jan 2024 08:22:04 -0800 Subject: [PATCH 16/51] Automated Code Change PiperOrigin-RevId: 598847797 --- lit_nlp/client/core/faceting_control.ts | 2 +- lit_nlp/client/core/global_settings.ts | 2 +- lit_nlp/client/elements/color_legend.ts | 2 +- lit_nlp/client/elements/generated_text_vis.ts | 4 ++-- lit_nlp/client/elements/interpreter_controls.ts | 2 +- lit_nlp/client/elements/table.ts | 2 +- lit_nlp/client/services/classification_service.ts | 2 +- lit_nlp/client/services/data_service.ts | 2 +- lit_nlp/client/services/modules_service.ts | 2 +- lit_nlp/client/services/selection_service.ts | 2 +- lit_nlp/client/services/slice_service.ts | 2 +- lit_nlp/client/services/state_service.ts | 2 +- 12 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lit_nlp/client/core/faceting_control.ts b/lit_nlp/client/core/faceting_control.ts index 7a3d78c1..91dbfd77 100644 --- a/lit_nlp/client/core/faceting_control.ts +++ b/lit_nlp/client/core/faceting_control.ts @@ -29,7 +29,7 @@ import {app} from '../core/app'; import {ReactiveElement} from '../lib/elements'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {getStepSizeGivenRange} from '../lib/utils'; -import {FacetingConfig, FacetingMethod, GroupService, NumericFeatureBins} from '../services/group_service'; +import {FacetingConfig, FacetingMethod, GroupService, type NumericFeatureBins} from '../services/group_service'; import {styles} from './faceting_control.css'; diff --git a/lit_nlp/client/core/global_settings.ts b/lit_nlp/client/core/global_settings.ts index 1d9d7b10..923d2f23 100644 --- a/lit_nlp/client/core/global_settings.ts +++ b/lit_nlp/client/core/global_settings.ts @@ -36,7 +36,7 @@ import {action, computed, observable} from 'mobx'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {StringLitType} from '../lib/lit_types'; -import {CallConfig, datasetDisplayName, LitTabGroupLayout, NONE_DS_DICT_KEY, Spec} from '../lib/types'; +import {type CallConfig, datasetDisplayName, LitTabGroupLayout, NONE_DS_DICT_KEY, Spec} from '../lib/types'; import {getTemplateStringFromMarkdown, validateCallConfig} from '../lib/utils'; import {LitInputField} from '../elements/lit_input_field'; import {resolveModuleConfig} from '../services/modules_service'; diff --git a/lit_nlp/client/elements/color_legend.ts b/lit_nlp/client/elements/color_legend.ts index 9e8f972e..9c48beca 100644 --- a/lit_nlp/client/elements/color_legend.ts +++ b/lit_nlp/client/elements/color_legend.ts @@ -28,7 +28,7 @@ import {computed, observable} from 'mobx'; import {DEFAULT} from '../lib/colors'; import {ReactiveElement} from '../lib/elements'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {D3Scale} from '../lib/types'; +import {type D3Scale} from '../lib/types'; import {getTextWidth, linearSpace} from '../lib/utils'; import {styles} from './color_legend.css'; diff --git a/lit_nlp/client/elements/generated_text_vis.ts b/lit_nlp/client/elements/generated_text_vis.ts index 6f79ffb4..23dbdea3 100644 --- a/lit_nlp/client/elements/generated_text_vis.ts +++ b/lit_nlp/client/elements/generated_text_vis.ts @@ -22,9 +22,9 @@ import {classMap} from 
'lit/directives/class-map.js'; import {computed, observable} from 'mobx'; import {ReactiveElement} from '../lib/elements'; -import {DiffMode, getTextDiff, TextDiff} from '../lib/generated_text_utils'; +import {DiffMode, getTextDiff, type TextDiff} from '../lib/generated_text_utils'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {ScoredTextCandidates} from '../lib/dtypes'; +import {type ScoredTextCandidates} from '../lib/dtypes'; import {styles} from './generated_text_vis.css'; diff --git a/lit_nlp/client/elements/interpreter_controls.ts b/lit_nlp/client/elements/interpreter_controls.ts index 8a548b18..c3b4e110 100644 --- a/lit_nlp/client/elements/interpreter_controls.ts +++ b/lit_nlp/client/elements/interpreter_controls.ts @@ -26,7 +26,7 @@ import {computed, observable} from 'mobx'; import {ReactiveElement} from '../lib/elements'; import {BooleanLitType, CategoryLabel, LitType, LitTypeWithVocab, MultiFieldMatcher, Scalar, SingleFieldMatcher, SparseMultilabel, Tokens} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {Spec} from '../lib/types'; +import {type Spec} from '../lib/types'; import {getTemplateStringFromMarkdown} from '../lib/utils'; import {styles} from './interpreter_controls.css'; diff --git a/lit_nlp/client/elements/table.ts b/lit_nlp/client/elements/table.ts index b63071d9..2645b472 100644 --- a/lit_nlp/client/elements/table.ts +++ b/lit_nlp/client/elements/table.ts @@ -53,7 +53,7 @@ export type TemplateResultFn = TemplateResult; /** Export types from ./table_types. */ -export {ColumnHeader, SortableTableEntry, SortableTemplateResult, TableData, TableEntry, TableRowInternal}; +export {type ColumnHeader, type SortableTableEntry, type SortableTemplateResult, type TableData, type TableEntry, type TableRowInternal}; /** Callback for selection */ export type OnSelectCallback = (selectedIndices: number[]) => void; diff --git a/lit_nlp/client/services/classification_service.ts b/lit_nlp/client/services/classification_service.ts index 438295be..6c28a761 100644 --- a/lit_nlp/client/services/classification_service.ts +++ b/lit_nlp/client/services/classification_service.ts @@ -19,7 +19,7 @@ import {action, computed, observable, reaction} from 'mobx'; import {MulticlassPreds} from '../lib/lit_types'; -import {FacetedData, GroupedExamples, SpecMap} from '../lib/types'; +import {type FacetedData, type GroupedExamples, type SpecMap} from '../lib/types'; import {getMarginFromThreshold} from '../lib/utils'; import {LitService} from './lit_service'; diff --git a/lit_nlp/client/services/data_service.ts b/lit_nlp/client/services/data_service.ts index 8b59100c..114fae3d 100644 --- a/lit_nlp/client/services/data_service.ts +++ b/lit_nlp/client/services/data_service.ts @@ -18,7 +18,7 @@ // tslint:disable:no-new-decorators import {action, computed, observable, reaction} from 'mobx'; -import {BINARY_NEG_POS, ColorRange} from '../lib/colors'; +import {BINARY_NEG_POS, type ColorRange} from '../lib/colors'; import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar, SparseMultilabelPreds} from '../lib/lit_types'; import {ClassificationResults, IndexedInput, RegressionResults} from '../lib/types'; import {createLitType, findSpecKeys, isLitSubtype, mapsContainSame} from '../lib/utils'; diff --git a/lit_nlp/client/services/modules_service.ts b/lit_nlp/client/services/modules_service.ts index 618feba4..6cdbc22b 100644 --- 
a/lit_nlp/client/services/modules_service.ts +++ b/lit_nlp/client/services/modules_service.ts @@ -18,7 +18,7 @@ // tslint:disable:no-new-decorators import {action, observable} from 'mobx'; -import {LayoutSettings, LitCanonicalLayout, LitComponentSpecifier, LitModuleClass, LitModuleConfig, LitTabGroupLayout, ModelInfoMap, ResolvedModuleConfig, Spec} from '../lib/types'; +import {LayoutSettings, type LitCanonicalLayout, LitComponentSpecifier, LitModuleClass, LitModuleConfig, LitTabGroupLayout, type ModelInfoMap, ResolvedModuleConfig, type Spec} from '../lib/types'; import {LitService} from './lit_service'; import {ModulesObservedByUrlService, UrlConfiguration} from './url_service'; diff --git a/lit_nlp/client/services/selection_service.ts b/lit_nlp/client/services/selection_service.ts index db3163a2..91ae92db 100644 --- a/lit_nlp/client/services/selection_service.ts +++ b/lit_nlp/client/services/selection_service.ts @@ -18,7 +18,7 @@ // tslint:disable:no-new-decorators import {action, computed, observable} from 'mobx'; -import {IndexedInput, ServiceUser} from '../lib/types'; +import {type IndexedInput, type ServiceUser} from '../lib/types'; import {LitService} from './lit_service'; import {SelectionObservedByUrlService} from './url_service'; diff --git a/lit_nlp/client/services/slice_service.ts b/lit_nlp/client/services/slice_service.ts index f97babeb..7c3fb200 100644 --- a/lit_nlp/client/services/slice_service.ts +++ b/lit_nlp/client/services/slice_service.ts @@ -18,7 +18,7 @@ // tslint:disable:no-new-decorators import {action, computed, observable, reaction} from 'mobx'; -import {IndexedInput, ServiceUser} from '../lib/types'; +import {IndexedInput, type ServiceUser} from '../lib/types'; import {arrayContainsSame} from '../lib/utils'; import {LitService} from './lit_service'; diff --git a/lit_nlp/client/services/state_service.ts b/lit_nlp/client/services/state_service.ts index e721247d..ee55daa3 100644 --- a/lit_nlp/client/services/state_service.ts +++ b/lit_nlp/client/services/state_service.ts @@ -19,7 +19,7 @@ import {action, computed, observable, toJS} from 'mobx'; import {FieldMatcher, ImageBytes} from '../lib/lit_types'; -import {defaultValueByField, IndexedInput, Input, LitCanonicalLayout, LitComponentLayouts, LitMetadata, ModelInfo, ModelInfoMap, ModelSpec, NONE_DS_DICT_KEY, Spec} from '../lib/types'; +import {defaultValueByField, IndexedInput, Input, type LitCanonicalLayout, type LitComponentLayouts, type LitMetadata, ModelInfo, type ModelInfoMap, ModelSpec, NONE_DS_DICT_KEY, type Spec} from '../lib/types'; import {findSpecKeys, getTypes} from '../lib/utils'; import {ApiService} from './api_service'; From 243a058db08995e5d8c789444a104dc852528eb0 Mon Sep 17 00:00:00 2001 From: Jan Kuehle Date: Tue, 16 Jan 2024 08:23:18 -0800 Subject: [PATCH 17/51] Automated Code Change PiperOrigin-RevId: 598848097 --- lit_nlp/client/modules/annotated_text_module.ts | 4 ++-- lit_nlp/client/modules/confusion_matrix_module.ts | 2 +- lit_nlp/client/modules/curves_module.ts | 2 +- lit_nlp/client/modules/data_table_module.ts | 2 +- lit_nlp/client/modules/datapoint_editor_module.ts | 2 +- lit_nlp/client/modules/feature_attribution_module.ts | 2 +- lit_nlp/client/modules/generated_text_module.ts | 4 ++-- lit_nlp/client/modules/generator_module.ts | 2 +- lit_nlp/client/modules/lm_prediction_module.ts | 2 +- lit_nlp/client/modules/metrics_module.ts | 2 +- lit_nlp/client/modules/regression_module.ts | 2 +- lit_nlp/client/modules/salience_map_module.ts | 2 +- lit_nlp/client/modules/sequence_salience_module.ts | 
4 ++-- lit_nlp/client/modules/tda_module.ts | 4 ++-- 14 files changed, 18 insertions(+), 18 deletions(-) diff --git a/lit_nlp/client/modules/annotated_text_module.ts b/lit_nlp/client/modules/annotated_text_module.ts index 800438c4..286d8c76 100644 --- a/lit_nlp/client/modules/annotated_text_module.ts +++ b/lit_nlp/client/modules/annotated_text_module.ts @@ -19,9 +19,9 @@ import { html} from 'lit'; import {observable} from 'mobx'; import {LitModule} from '../core/lit_module'; -import {AnnotationGroups, TextSegments} from '../elements/annotated_text_vis'; +import {type AnnotationGroups, TextSegments} from '../elements/annotated_text_vis'; import {MultiSegmentAnnotations, TextSegment} from '../lib/lit_types'; -import {IndexedInput, ModelInfoMap, Spec} from '../lib/types'; +import {type IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, filterToKeys, findSpecKeys} from '../lib/utils'; import {styles as sharedStyles} from '../lib/shared_styles.css'; diff --git a/lit_nlp/client/modules/confusion_matrix_module.ts b/lit_nlp/client/modules/confusion_matrix_module.ts index fd9af769..1c8fcea6 100644 --- a/lit_nlp/client/modules/confusion_matrix_module.ts +++ b/lit_nlp/client/modules/confusion_matrix_module.ts @@ -29,7 +29,7 @@ import {MulticlassPreds} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {GroupedExamples, IndexedInput, ModelInfoMap} from '../lib/types'; import {arrayContainsSame, doesOutputSpecContain, facetMapToDictKey} from '../lib/utils'; -import {FacetingMethod, GetFeatureFunc, GroupService, NumericFeatureBins} from '../services/group_service'; +import {FacetingMethod, GetFeatureFunc, GroupService, type NumericFeatureBins} from '../services/group_service'; import {DataService, SelectionService} from '../services/services'; import {styles} from './confusion_matrix_module.css'; diff --git a/lit_nlp/client/modules/curves_module.ts b/lit_nlp/client/modules/curves_module.ts index 0d428b57..27650837 100644 --- a/lit_nlp/client/modules/curves_module.ts +++ b/lit_nlp/client/modules/curves_module.ts @@ -28,7 +28,7 @@ import {FacetsChange} from '../core/faceting_control'; import {LitModule} from '../core/lit_module'; import {MulticlassPreds} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {GroupedExamples, IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; +import {type GroupedExamples, IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {findSpecKeys, hasValidParent} from '../lib/utils'; import {NumericFeatureBins} from '../services/group_service'; import {GroupService} from '../services/services'; diff --git a/lit_nlp/client/modules/data_table_module.ts b/lit_nlp/client/modules/data_table_module.ts index bf41ce31..7f6a1a43 100644 --- a/lit_nlp/client/modules/data_table_module.ts +++ b/lit_nlp/client/modules/data_table_module.ts @@ -30,7 +30,7 @@ import {LitModule} from '../core/lit_module'; import {ColumnHeader, DataTable, SortableTemplateResult, TableData, TableEntry} from '../elements/table'; import {BooleanLitType, LitType, LitTypeWithVocab, URLLitType} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {formatForDisplay, IndexedInput, ModelInfoMap, Spec} from '../lib/types'; +import {formatForDisplay, IndexedInput, ModelInfoMap, type Spec} from '../lib/types'; import {compareArrays} from '../lib/utils'; import {DataService, FocusService, 
SelectionService, SliceService} from '../services/services'; import {STARRED_SLICE_NAME} from '../services/slice_service'; diff --git a/lit_nlp/client/modules/datapoint_editor_module.ts b/lit_nlp/client/modules/datapoint_editor_module.ts index c80f5d66..a27460e3 100644 --- a/lit_nlp/client/modules/datapoint_editor_module.ts +++ b/lit_nlp/client/modules/datapoint_editor_module.ts @@ -30,7 +30,7 @@ import {LitModule} from '../core/lit_module'; import {AnnotationCluster, EdgeLabel, SpanLabel} from '../lib/dtypes'; import {BooleanLitType, EdgeLabels, Embeddings, ImageBytes, ListLitType, LitTypeWithVocab, MultiSegmentAnnotations, Scalar, SearchQuery, SequenceTags, SpanLabels, SparseMultilabel, StringLitType, Tokens, URLLitType} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {formatAnnotationCluster, formatEdgeLabel, formatSpanLabel, IndexedInput, Input, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; +import {formatAnnotationCluster, formatEdgeLabel, formatSpanLabel, type IndexedInput, type Input, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {findSpecKeys, isLitSubtype, makeModifiedInput} from '../lib/utils'; import {GroupService} from '../services/group_service'; import {SelectionService} from '../services/selection_service'; diff --git a/lit_nlp/client/modules/feature_attribution_module.ts b/lit_nlp/client/modules/feature_attribution_module.ts index 2f84cfbd..ae07f221 100644 --- a/lit_nlp/client/modules/feature_attribution_module.ts +++ b/lit_nlp/client/modules/feature_attribution_module.ts @@ -34,7 +34,7 @@ import {IndexedInput, ModelInfoMap} from '../lib/types'; import * as utils from '../lib/utils'; import {findSpecKeys} from '../lib/utils'; import {SignedSalienceCmap} from '../services/color_service'; -import {NumericFeatureBins} from '../services/group_service'; +import {type NumericFeatureBins} from '../services/group_service'; import {AppState, GroupService} from '../services/services'; import {styles as sharedStyles} from '../lib/shared_styles.css'; diff --git a/lit_nlp/client/modules/generated_text_module.ts b/lit_nlp/client/modules/generated_text_module.ts index 0f34bf13..0c0ba927 100644 --- a/lit_nlp/client/modules/generated_text_module.ts +++ b/lit_nlp/client/modules/generated_text_module.ts @@ -26,10 +26,10 @@ import {computed, observable} from 'mobx'; import {LitModule} from '../core/lit_module'; import {styles as visStyles} from '../elements/generated_text_vis.css'; import {LitSwitch} from '../elements/switch'; -import {DiffMode, GeneratedTextResult, GENERATION_TYPES} from '../lib/generated_text_utils'; +import {DiffMode, type GeneratedTextResult, GENERATION_TYPES} from '../lib/generated_text_utils'; import {GeneratedText, GeneratedTextCandidates, LitTypeWithParent, ReferenceScores, ReferenceTexts} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types'; +import {IndexedInput, type Input, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys} from '../lib/utils'; import {styles} from './generated_text_module.css'; diff --git a/lit_nlp/client/modules/generator_module.ts b/lit_nlp/client/modules/generator_module.ts index b5fed37a..79af59bb 100644 --- a/lit_nlp/client/modules/generator_module.ts +++ b/lit_nlp/client/modules/generator_module.ts @@ -30,7 +30,7 @@ import {LitModule} from '../core/lit_module'; import {TableData, TableEntry} from 
'../elements/table'; import {EdgeLabels, FieldMatcher, LitTypeTypesList, SpanLabels} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {CallConfig, formatForDisplay, IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types'; +import {CallConfig, formatForDisplay, IndexedInput, type Input, ModelInfoMap, Spec} from '../lib/types'; import {cloneSpec, flatten, isLitSubtype} from '../lib/utils'; import {GroupService} from '../services/group_service'; import {SelectionService, SliceService} from '../services/services'; diff --git a/lit_nlp/client/modules/lm_prediction_module.ts b/lit_nlp/client/modules/lm_prediction_module.ts index 94b60a2c..90f7aeab 100644 --- a/lit_nlp/client/modules/lm_prediction_module.ts +++ b/lit_nlp/client/modules/lm_prediction_module.ts @@ -26,7 +26,7 @@ import {computed, observable} from 'mobx'; import {LitModule} from '../core/lit_module'; import {TextSegment, Tokens, TokenTopKPreds} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {IndexedInput, ModelInfoMap, Spec, TopKResult} from '../lib/types'; +import {type IndexedInput, ModelInfoMap, Spec, TopKResult} from '../lib/types'; import {findMatchingIndices, findSpecKeys, makeModifiedInput, replaceNth} from '../lib/utils'; import {styles} from './lm_prediction_module.css'; diff --git a/lit_nlp/client/modules/metrics_module.ts b/lit_nlp/client/modules/metrics_module.ts index daa00dd9..ebd6e0ce 100644 --- a/lit_nlp/client/modules/metrics_module.ts +++ b/lit_nlp/client/modules/metrics_module.ts @@ -30,7 +30,7 @@ import {MetricBestValue, MetricResult} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {CallConfig, FacetMap, IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {MetricsResponse, MetricsValues} from '../services/api_service'; -import {GroupService, NumericFeatureBins} from '../services/group_service'; +import {GroupService, type NumericFeatureBins} from '../services/group_service'; import {ClassificationService, SliceService} from '../services/services'; // A dict of metrics type to the MetricsValues for one metric generator. 
diff --git a/lit_nlp/client/modules/regression_module.ts b/lit_nlp/client/modules/regression_module.ts index e438b7d5..ad85efb0 100644 --- a/lit_nlp/client/modules/regression_module.ts +++ b/lit_nlp/client/modules/regression_module.ts @@ -25,7 +25,7 @@ import {LitModule} from '../core/lit_module'; import {TableData} from '../elements/table'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {RegressionScore} from '../lib/lit_types'; -import {IndexedInput, ModelInfoMap, RegressionResults, Spec} from '../lib/types'; +import {IndexedInput, ModelInfoMap, type RegressionResults, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys} from '../lib/utils'; import {CalculatedColumnType} from '../services/data_service'; import {DataService} from '../services/services'; diff --git a/lit_nlp/client/modules/salience_map_module.ts b/lit_nlp/client/modules/salience_map_module.ts index cdd8d473..c05cbae3 100644 --- a/lit_nlp/client/modules/salience_map_module.ts +++ b/lit_nlp/client/modules/salience_map_module.ts @@ -37,7 +37,7 @@ import {InterpreterClick} from '../elements/interpreter_controls'; import {TokenWithWeight} from '../elements/token_chips'; import {FeatureSalience, FieldMatcher, ImageGradients, ImageSalience, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, RegressionScore, Salience, SalienceTargetInfo, TokenGradients, TokenSalience, FrameSalience} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {CallConfig, IndexedInput, ModelInfoMap, Preds, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; +import {CallConfig, type IndexedInput, ModelInfoMap, type Preds, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {argmax, cloneSpec, findSpecKeys, makeModifiedInput} from '../lib/utils'; import {SalienceCmap, SignedSalienceCmap, UnsignedSalienceCmap} from '../services/color_service'; import {FocusService} from '../services/focus_service'; diff --git a/lit_nlp/client/modules/sequence_salience_module.ts b/lit_nlp/client/modules/sequence_salience_module.ts index 18aa46e3..a18af539 100644 --- a/lit_nlp/client/modules/sequence_salience_module.ts +++ b/lit_nlp/client/modules/sequence_salience_module.ts @@ -16,10 +16,10 @@ import {LegendType} from '../elements/color_legend'; import {TokenWithWeight} from '../elements/token_chips'; import {SignedSalienceCmap, UnsignedSalienceCmap} from '../lib/colors'; import {SequenceSalienceMap} from '../lib/dtypes'; -import {canonicalizeGenerationResults, GeneratedTextResult, GENERATION_TYPES, getAllTargetOptions, TargetOption} from '../lib/generated_text_utils'; +import {canonicalizeGenerationResults, type GeneratedTextResult, GENERATION_TYPES, getAllTargetOptions, TargetOption} from '../lib/generated_text_utils'; import {Salience} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {IndexedInput, ModelInfoMap, Spec} from '../lib/types'; +import {type IndexedInput, ModelInfoMap, type Spec} from '../lib/types'; import {sumArray} from '../lib/utils'; import {styles} from './sequence_salience_module.css'; diff --git a/lit_nlp/client/modules/tda_module.ts b/lit_nlp/client/modules/tda_module.ts index 03ef9b51..461d24b3 100644 --- a/lit_nlp/client/modules/tda_module.ts +++ b/lit_nlp/client/modules/tda_module.ts @@ -28,10 +28,10 @@ import {computed, observable} from 'mobx'; import {app} from '../core/app'; import {LitModule} from '../core/lit_module'; import {TableData, TableEntry} from '../elements/table'; -import 
{canonicalizeGenerationResults, GeneratedTextResult, GENERATION_TYPES, getAllOutputTexts, getFlatTexts} from '../lib/generated_text_utils'; +import {canonicalizeGenerationResults, type GeneratedTextResult, GENERATION_TYPES, getAllOutputTexts, getFlatTexts} from '../lib/generated_text_utils'; import {FieldMatcher, InfluentialExamples, LitTypeWithParent} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {CallConfig, ComponentInfoMap, IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types'; +import {CallConfig, ComponentInfoMap, type IndexedInput, type Input, ModelInfoMap, Spec} from '../lib/types'; import {cloneSpec, filterToKeys, findSpecKeys, makeModifiedInput} from '../lib/utils'; import {AppState, SelectionService} from '../services/services'; From 2138bd920e72553f9c920ba489962c8649738574 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Wed, 31 Jan 2024 11:21:20 -0800 Subject: [PATCH 18/51] Isolate TF dependencies to lit_nlp.examples module. PiperOrigin-RevId: 603107736 --- lit_nlp/{components => examples/models}/tfx_model.py | 0 lit_nlp/{components => examples/models}/tfx_model_test.py | 2 +- pyproject.toml | 4 ++-- requirements_core.txt | 2 -- requirements_examples.txt | 2 ++ 5 files changed, 5 insertions(+), 5 deletions(-) rename lit_nlp/{components => examples/models}/tfx_model.py (100%) rename lit_nlp/{components => examples/models}/tfx_model_test.py (97%) diff --git a/lit_nlp/components/tfx_model.py b/lit_nlp/examples/models/tfx_model.py similarity index 100% rename from lit_nlp/components/tfx_model.py rename to lit_nlp/examples/models/tfx_model.py diff --git a/lit_nlp/components/tfx_model_test.py b/lit_nlp/examples/models/tfx_model_test.py similarity index 97% rename from lit_nlp/components/tfx_model_test.py rename to lit_nlp/examples/models/tfx_model_test.py index 31dcf628..4b44b2d4 100644 --- a/lit_nlp/components/tfx_model_test.py +++ b/lit_nlp/examples/models/tfx_model_test.py @@ -2,7 +2,7 @@ import tempfile from lit_nlp.api import types as lit_types -from lit_nlp.components import tfx_model +from lit_nlp.examples.models import tfx_model import tensorflow as tf diff --git a/pyproject.toml b/pyproject.toml index 3c8bb401..e3ccdde7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,6 @@ dependencies = [ "scipy==1.10.1", "shap==0.37.0", "six==1.16.0", - "tensorflow==2.10.0", - "tensorflow-text==2.10.0", "termcolor==2.3.0", "tqdm==4.64.0", "umap-learn==0.5.1", @@ -82,7 +80,9 @@ keywords = [ # LINT.IfChange examples = [ "gunicorn==20.1.0", + "tensorflow==2.10.0", "tensorflow-datasets==4.8.0", + "tensorflow-text==2.10.0", "torch==2.0.1", "transformers==4.27.1", ] diff --git a/requirements_core.txt b/requirements_core.txt index e3fe7551..98a4b651 100644 --- a/requirements_core.txt +++ b/requirements_core.txt @@ -35,8 +35,6 @@ scikit-learn==1.0.2 scipy==1.10.1 shap==0.37.0 six==1.16.0 -tensorflow==2.10.0 -tensorflow-text==2.10.0 termcolor==2.3.0 tqdm==4.64.0 umap-learn==0.5.1 diff --git a/requirements_examples.txt b/requirements_examples.txt index d3353493..c19c4e1b 100644 --- a/requirements_examples.txt +++ b/requirements_examples.txt @@ -15,7 +15,9 @@ # LINT.IfChange gunicorn==20.1.0 sentencepiece==0.1.99 +tensorflow==2.10.0 tensorflow-datasets==4.8.0 +tensorflow-text==2.10.0 torch==2.0.1 transformers==4.27.1 # LINT.ThenChange(./pyproject.toml) From 27e6901164044c0d33658603369a55600da0b202 Mon Sep 17 00:00:00 2001 From: Bin Du Date: Tue, 6 Feb 2024 08:23:30 -0800 Subject: [PATCH 19/51] GPT2 Generative model. 
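Example usage (a minimal sketch, not part of this change; the checkpoint
name, prompt text, and max_new_tokens value below are illustrative, and the
input/output keys follow the specs defined in pretrained_lms.py):

    from lit_nlp.examples.models import pretrained_lms

    # Wrap a GPT-2 checkpoint; "gpt2" resolves via HuggingFace Transformers.
    model = pretrained_lms.GPT2GenerativeModel(
        model_name_or_path="gpt2", max_new_tokens=20)

    # predict() batches internally (see max_minibatch_size below) and yields
    # one output dict per input example.
    preds = list(model.predict([{"prompt": "Today is"}]))
    print(preds[0]["response"])  # sampled continuation of the prompt
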
PiperOrigin-RevId: 604654297 --- lit_nlp/examples/models/pretrained_lms.py | 114 ++++++++++++++++++ .../models/pretrained_lms_int_test.py | 17 +++ 2 files changed, 131 insertions(+) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 022b35cd..8b72f8fe 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -324,3 +324,117 @@ def output_spec(self): align_in="tokens", align_out="tokens") spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings() return spec + + +class GPT2GenerativeModel(lit_model.BatchedModel): + """Wrapper for a Huggingface Transformers GPT-2 model. + + This class loads a tokenizer and model using the Huggingface library and + provides the LIT-required functions to generate text responses given input + prompts. + + Note that the default model generation config is used such that the response + is produced using multinomial sampling. + """ + + @classmethod + def init_spec(cls) -> lit_model.Spec: + return { + "model_name_or_path": lit_types.String(default="gpt2"), + "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500), + "batch_size": lit_types.Integer(default=6, min_val=1, max_val=25), + } + + def __init__( + self, + model=None, + tokenizer=None, + model_name_or_path="gpt2", + max_new_tokens=50, + batch_size=6, + ): + """Constructor for GPT2LanguageModel. + + Note: args "model" and "tokenizer" take priority if both are specified. + Otherwise, "model_name_or_path" is used to initialize the model and + tokenizer. + + Args: + model: an initialized GPT2 model compatible with Tensorflow. + tokenizer: an initialized GPT2 tokenizer. + model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, + etc. + max_new_tokens: the maximum number of new tokens to generate. + batch_size: the number of items to process per `predict_minibatch` call. + """ + super().__init__() + + if model is not None and tokenizer is not None: + self.model = model + self.tokenizer = tokenizer + else: + # Normally path is a directory; if it's an archive file, download and + # extract to the transformers cache. + if model_name_or_path.endswith(".tar.gz"): + model_name_or_path = file_cache.cached_path( + model_name_or_path, extract_compressed_file=True + ) + + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name_or_path, use_fast=False + ) + # Set this after init, as if pad_token= is passed to + # AutoTokenizer.from_pretrained() above it will create a new token with + # with id = max_vocab_length and cause out-of-bounds errors in + # the embedding lookup. + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = transformers.TFAutoModelForCausalLM.from_pretrained( + model_name_or_path + ) + + self.max_new_tokens = max_new_tokens + self.batch_size = batch_size + + ## + # LIT API implementations + def max_minibatch_size(self) -> int: + # The BatchedModel base class handles batching automatically in the + # implementation of predict(), and uses this value as the batch size. 
+ return self.batch_size + + def predict_minibatch(self, inputs): + prompts = [ex["prompt"] for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + prompts, + return_tensors="tf", + add_special_tokens=True, + padding="longest", + truncation="longest_first", + ) + outputs = self.model.generate( + encoded_inputs["input_ids"], + max_new_tokens=self.max_new_tokens, + ) + responses = self.tokenizer.batch_decode( + outputs[:, -self.max_new_tokens :], skip_special_tokens=True + ) + embeddings = self.model.transformer.wte(outputs) + return [ + { + "response": responses[i], + "prompt_embeddings": embeddings[i, : -self.max_new_tokens], + "response_embeddings": embeddings[i, -self.max_new_tokens :] + } for i in range(len(outputs)) + ] + + def input_spec(self): + return { + "prompt": lit_types.TextSegment(), + } + + def output_spec(self) -> lit_types.Spec: + return { + "response": lit_types.GeneratedTextCandidates(), + "prompt_embeddings": lit_types.Embeddings(required=False), + "response_embeddings": lit_types.Embeddings(required=False) + } diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py index f70c6893..62583ef6 100644 --- a/lit_nlp/examples/models/pretrained_lms_int_test.py +++ b/lit_nlp/examples/models/pretrained_lms_int_test.py @@ -31,5 +31,22 @@ def test_gpt2(self): for key in model.output_spec().keys(): self.assertIn(key, model_out[0].keys()) + def test_gpt2_generation(self): + # Run prediction to ensure no failure. + model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz" + model = pretrained_lms.GPT2GenerativeModel(model_name_or_path=model_path) + model_in = [{"prompt": "Today is"}, {"prompt": "What is the color of"}] + model_out = list(model.predict(model_in)) + + # Sanity-check output vs output spec. + self.assertLen(model_out, 2) + for key in model.output_spec().keys(): + self.assertIn(key, model_out[0].keys()) + + # Check that the embedding dimension is the same for prompt and response. + self.assertEqual(model_out[0]["prompt_embeddings"].shape[1], + model_out[0]["response_embeddings"].shape[1]) + + if __name__ == "__main__": absltest.main() From ab294bd3e15675c0e63e5a16ffe4b8cd4941c94f Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Mon, 12 Feb 2024 13:21:58 -0800 Subject: [PATCH 20/51] Utils and helpers for sequence salience, most notably token grouping code. PiperOrigin-RevId: 606346156 --- lit_nlp/client/elements/tooltip.css | 8 +++ lit_nlp/client/elements/tooltip.ts | 1 + lit_nlp/client/lib/token_utils.ts | 53 ++++++++++++++++ lit_nlp/client/lib/token_utils_test.ts | 88 ++++++++++++++++++++++++++ lit_nlp/client/lib/utils.ts | 21 ++++++ lit_nlp/client/lib/utils_test.ts | 13 ++++ lit_nlp/lib/utils.py | 15 ++++- lit_nlp/lib/utils_test.py | 48 +++++++++++++- 8 files changed, 243 insertions(+), 4 deletions(-) create mode 100644 lit_nlp/client/lib/token_utils.ts create mode 100644 lit_nlp/client/lib/token_utils_test.ts diff --git a/lit_nlp/client/elements/tooltip.css b/lit_nlp/client/elements/tooltip.css index 0f1c7038..86ae7243 100644 --- a/lit_nlp/client/elements/tooltip.css +++ b/lit_nlp/client/elements/tooltip.css @@ -10,6 +10,10 @@ * with tooltip positioning. 
*/ --anchor-display-mode: inline-block; + --tooltip-position-left: unset; + --tooltip-position-right: unset; + --tooltip-position-top: unset; + --tooltip-position-bottom: unset; } /* Tooltip */ @@ -34,6 +38,10 @@ font-size: 12px; font-weight: normal; line-height: 16px; + left: var(--tooltip-position-left); + right: var(--tooltip-position-right); + top: var(--tooltip-position-top); + bottom: var(--tooltip-position-bottom); display: -webkit-box; -webkit-line-clamp: 6; diff --git a/lit_nlp/client/elements/tooltip.ts b/lit_nlp/client/elements/tooltip.ts index 14e693f7..62749cd3 100644 --- a/lit_nlp/client/elements/tooltip.ts +++ b/lit_nlp/client/elements/tooltip.ts @@ -71,6 +71,7 @@ export class LitTooltip extends ReactiveElement { 'disabled': this.disabled, }); + // prettier-ignore return html`
${this.content === '' ? '' : html` diff --git a/lit_nlp/client/lib/token_utils.ts b/lit_nlp/client/lib/token_utils.ts new file mode 100644 index 00000000..b81f04ae --- /dev/null +++ b/lit_nlp/client/lib/token_utils.ts @@ -0,0 +1,53 @@ +/** + * @fileoverview Utils for working with tokenized text. + */ + +/** + * Evil underscore used by sentencepiece to replace spaces. + */ +export const SPM_SPACE_SENTINEL = '▁'; + +/** + * Clean SPM text to make it more human-readable. + */ +export function cleanSpmText(text: string): string { + return text.replaceAll(SPM_SPACE_SENTINEL, ' '); +} + +/** + * Use a regex to match segment prefixes. The prefix and anything + * following it (until the next match) are treated as one segment. + */ +export function groupTokensByRegexPrefix( + tokens: string[], + matcher: RegExp, + ): string[][] { + const text = tokens.join(''); + const matches = [...text.matchAll(matcher)]; + + let textCharOffset = 0; // chars into text + let matchIdx = 0; // indices into matches + const groups: string[][] = []; + let acc: string[] = []; + for (let i = 0; i < tokens.length; i++) { + const token = tokens[i]; + const nextMatch = matches[matchIdx]; + + // Look ahead to see if this token intrudes on a match. + // If so, start a new segment before pushing the token. + if (nextMatch !== undefined && + textCharOffset + token.length > nextMatch.index!) { + // Don't push an empty group if the first token is part of a match. + if (acc.length > 0 || groups.length > 0) groups.push(acc); + acc = []; + matchIdx += 1; + } + + // Push the token. + acc.push(token); + textCharOffset += token.length; + } + // Finally, push any open group. + if (acc.length > 0) groups.push(acc); + return groups; +} \ No newline at end of file diff --git a/lit_nlp/client/lib/token_utils_test.ts b/lit_nlp/client/lib/token_utils_test.ts new file mode 100644 index 00000000..76223106 --- /dev/null +++ b/lit_nlp/client/lib/token_utils_test.ts @@ -0,0 +1,88 @@ +/** + * Testing for token_utils.ts + */ + +import 'jasmine'; + +import * as tokenUtils from './token_utils'; + +describe('cleanSpmText test', () => { + it('cleans magic underscores from SPM output', () => { + const text = 'Summarize▁this▁sentence:\n\nOnce▁upon▁a▁time'; + expect(tokenUtils.cleanSpmText(text)) + .toEqual('Summarize this sentence:\n\nOnce upon a time'); + }); +}); + +describe('groupTokensByRegexPrefix test', () => { + [{ + testcaseName: 'groups tokens by word', + tokens: ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], + regex: /[▁\s]+/g, + expectedGroups: [['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':']], + }, + { + testcaseName: 'groups tokens by word, handling newlines', + tokens: [ + 'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once', + '▁upon', '▁a', '▁time' + ], + // Consecutive newlines should be their own segment. + // Start a new word on the first non-\n afterwards. + regex: /([▁\s]+)|(?<=\n)[^\n]/g, + expectedGroups: [ + ['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':'], ['\n', '\n'], + ['Once'], ['▁upon'], ['▁a'], ['▁time'] + ], + }, + { + testcaseName: 'groups tokens by sentence, simple version', + tokens: [ + 'Sent', 'ence', '▁one', '.', '▁Sent', 'ence', '▁two', '!', '▁Sent', + 'ence', '▁three', '?' 
+ ], + regex: /(?<=[.?!])[▁\s]+/g, + expectedGroups: [ + ['Sent', 'ence', '▁one', '.'], + ['▁Sent', 'ence', '▁two', '!'], + ['▁Sent', 'ence', '▁three', '?'], + ], + }, + { + testcaseName: 'groups tokens by sentence, handling newlines', + tokens: [ + 'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once', + '▁upon', '▁a', '▁time' + ], + // Sentence start is one of: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + // - whitespace or magic underscore following punctuation [.?!] + regex: /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([▁\s]+))/g, + expectedGroups: [ + ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'], + ['Once', '▁upon', '▁a', '▁time'] + ], + }, + { + testcaseName: 'groups tokens by line', + tokens: [ + 'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once', + '▁upon', '▁a', '▁time' + ], + // Line start is either: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + regex: /(\n+)|([^\n]+)/g, + expectedGroups: [ + ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'], + ['Once', '▁upon', '▁a', '▁time'] + ], + }, + ].forEach(({testcaseName, tokens, regex, expectedGroups}) => { + it(testcaseName, () => { + const groups = tokenUtils.groupTokensByRegexPrefix(tokens, regex); + expect(groups).toEqual(expectedGroups); + }); + }); +}); \ No newline at end of file diff --git a/lit_nlp/client/lib/utils.ts b/lit_nlp/client/lib/utils.ts index 99ca1cef..b35d782c 100644 --- a/lit_nlp/client/lib/utils.ts +++ b/lit_nlp/client/lib/utils.ts @@ -302,6 +302,27 @@ export function cumSumArray(array: number[]) { return newArray; } +/** + * Group elements of one list to match the partitions of another. + * + * Example: + * groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']]) + * + * Should return: [[0, 1], [2], [3, 4, 5]] + */ +export function groupAlike(items: T[], groups: unknown[][]): T[][] { + const offsets = [0, ...cumSumArray(groups.map(g => g.length))]; + if (offsets.at(-1) !== items.length) { + throw new Error(`Total length of groups (${ + offsets.at(-1)}) !== number of items (${items.length}).`); + } + const ret = []; + for (let i = 0; i < groups.length; i++) { + ret.push(items.slice(offsets[i], offsets[i + 1])); + } + return ret; +} + /** * Python-style array comparison. * Compare on first element, then second, and so on until a mismatch is found. diff --git a/lit_nlp/client/lib/utils_test.ts b/lit_nlp/client/lib/utils_test.ts index eb712879..d4cec4f3 100644 --- a/lit_nlp/client/lib/utils_test.ts +++ b/lit_nlp/client/lib/utils_test.ts @@ -436,6 +436,19 @@ describe('cumSumArray test', () => { }); }); +describe('groupAlike test', () => { + it('groups items', () => { + const result = utils.groupAlike( + [0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']]); + expect(result).toEqual([[0, 1], [2], [3, 4, 5]]); + }); + it('raises an error if lengths do not match', () => { + expect(() => utils.groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c']])) + .toThrow( + new Error('Total length of groups (3) !== number of items (6).')); + }); +}); + describe('compareArrays test', () => { it('Correctly tests normal comparison', () => { // Shorter arrays. 
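The lit_nlp/lib/utils.py hunk below extends pad1d() with left-padding and an
optional max_len truncation. A minimal usage sketch (example values are
illustrative only; left-padding keeps the rightmost elements, the usual
convention when batching inputs for a decoder-only LM):

    from lit_nlp.lib import utils

    # Existing behavior: right-pad to min_len.
    utils.pad1d([1, 2, 3], min_len=5, pad_val=0)
    # -> [1, 2, 3, 0, 0]

    # New: left-pad, and truncate to max_len keeping the rightmost elements.
    utils.pad1d([1, 2, 3], min_len=5, pad_val=0, pad_left=True)
    # -> [0, 0, 1, 2, 3]
    utils.pad1d([1, 2, 3, 4, 5], min_len=3, pad_val=0, pad_left=True, max_len=3)
    # -> [3, 4, 5]
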
diff --git a/lit_nlp/lib/utils.py b/lit_nlp/lib/utils.py index 74b8986f..4351465d 100644 --- a/lit_nlp/lib/utils.py +++ b/lit_nlp/lib/utils.py @@ -198,9 +198,20 @@ def unbatch_preds( yield {key: value[i] for key, value in preds.items()} -def pad1d(arr: list[T], min_len: int, pad_val: T) -> list[T]: +def pad1d( + arr: list[T], + min_len: int, + pad_val: T, + pad_left: bool = False, + max_len: int | None = None, +) -> list[T]: """Pad a list to the target length.""" - return arr + [pad_val] * max(0, min_len - len(arr)) + if pad_left: + padded = [pad_val] * max(0, min_len - len(arr)) + arr + return padded[-max_len:] if max_len is not None else padded + else: + padded = arr + [pad_val] * max(0, min_len - len(arr)) + return padded[:max_len] if max_len is not None else padded def find_all_combinations( diff --git a/lit_nlp/lib/utils_test.py b/lit_nlp/lib/utils_test.py index c920dcb4..5646181f 100644 --- a/lit_nlp/lib/utils_test.py +++ b/lit_nlp/lib/utils_test.py @@ -252,11 +252,55 @@ def test_batch_inputs_raises( pad_val="", expected=["one", "two", "three", "", ""], ), + dict( + testcase_name="truncate_max_len", + inputs=[1, 2, 3, 4, 5], + min_len=3, + pad_val=0, + max_len=3, + expected=[1, 2, 3], + ), + dict( + testcase_name="pad_left", + inputs=[1, 2, 3], + min_len=5, + pad_val=0, + pad_left=True, + expected=[0, 0, 1, 2, 3], + ), + dict( + testcase_name="truncate_max_len_left", + inputs=[1, 2, 3, 4, 5], + min_len=3, + pad_val=0, + pad_left=True, + max_len=3, + expected=[3, 4, 5], + ), + dict( + testcase_name="pad_left_with_strings", + inputs=["one", "two", "three"], + min_len=5, + pad_val="", + pad_left=True, + expected=["", "", "one", "two", "three"], + ), ) def test_pad1d( - self, inputs: list[T], min_len: T, pad_val: T, expected: list[T] + self, + inputs: list[T], + min_len: T, + pad_val: T, + expected: list[T], + pad_left: bool = False, + max_len: int | None = None, ): - self.assertEqual(utils.pad1d(inputs, min_len, pad_val), expected) + self.assertEqual( + utils.pad1d( + inputs, min_len, pad_val, pad_left=pad_left, max_len=max_len + ), + expected, + ) @parameterized.named_parameters( dict( From 5cffc4d933e611587b00c25861c911d5f734fa22 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Mon, 12 Feb 2024 15:54:57 -0800 Subject: [PATCH 21/51] Misc frontend changes to support sequence salience. 
- Updates to - accepts labels via slots, so don't need to be just text - Specify default left/right split for three-panel layouts PiperOrigin-RevId: 606392017 --- lit_nlp/api/layout.py | 1 + lit_nlp/client/core/modules.ts | 14 +++- lit_nlp/client/elements/switch.ts | 9 ++- lit_nlp/client/elements/token_chips.css | 47 ++++++++++++ lit_nlp/client/elements/token_chips.ts | 85 ++++++++++++++++++--- lit_nlp/client/elements/token_chips_test.ts | 68 ++++++++++++++++- lit_nlp/client/lib/types.ts | 4 +- 7 files changed, 212 insertions(+), 16 deletions(-) diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py index fe65d412..637b3f19 100644 --- a/lit_nlp/api/layout.py +++ b/lit_nlp/api/layout.py @@ -91,6 +91,7 @@ class ModuleConfig(dtypes.DataTuple): class LayoutSettings(dtypes.DataTuple): hideToolbar: bool = False mainHeight: int = 45 + leftWidth: int = 50 centerPage: bool = False diff --git a/lit_nlp/client/core/modules.ts b/lit_nlp/client/core/modules.ts index d4503776..ffb53b85 100644 --- a/lit_nlp/client/core/modules.ts +++ b/lit_nlp/client/core/modules.ts @@ -151,6 +151,12 @@ export class LitModules extends ReactiveElement { (mainHeight) => { if (mainHeight != null) {this.upperHeight = `${mainHeight}%`;} }); + this.reactImmediately( + () => this.modulesService.getSetting('leftWidth'), (leftWidth) => { + if (leftWidth != null) { + this.leftColumnWidth = `${leftWidth}%`; + } + }); document.addEventListener('keydown', (e: KeyboardEvent) => { if (e.key === 'Escape') { @@ -422,7 +428,13 @@ export class LitModules extends ReactiveElement { const columnSeparatorDoubleClick = (event: DragEvent) => { event.stopPropagation(); event.preventDefault(); - this.leftColumnWidth = LEFT_COLUMN_DEFAULT_WIDTH; + const layoutDefaultLeftWidth = + this.modulesService.getSetting('leftWidth'); + if (layoutDefaultLeftWidth != null) { + this.leftColumnWidth = `${layoutDefaultLeftWidth}%`; + } else { + this.leftColumnWidth = LEFT_COLUMN_DEFAULT_WIDTH; + } }; const leftColumnStyles = styleMap({ diff --git a/lit_nlp/client/elements/switch.ts b/lit_nlp/client/elements/switch.ts index 7bd383e4..74118e4f 100644 --- a/lit_nlp/client/elements/switch.ts +++ b/lit_nlp/client/elements/switch.ts @@ -85,12 +85,17 @@ export class LitSwitch extends LitElement { 'selected': this.selected }); + // prettier-ignore return html`
-
${this.labelLeft}
+
+ ${this.labelLeft} +
-
${this.labelRight}
+
+ ${this.labelRight} +
`; } diff --git a/lit_nlp/client/elements/token_chips.css b/lit_nlp/client/elements/token_chips.css index 8cc33112..bea68608 100644 --- a/lit_nlp/client/elements/token_chips.css +++ b/lit_nlp/client/elements/token_chips.css @@ -52,4 +52,51 @@ .pre-wrap { white-space: pre-wrap; +} + +.row-break { + flex-basis: 100%; + height: 0; +} + +.word-spacer { + width: 1em; +} + +.tokens-holder-dense .word-spacer { + width: 0.5em; +} + +/* block mode */ +.tokens-holder-display-block { + display: block; + font-size: 0; /* hack to get zero spacing between elements */ + line-height: 22px; +} + +.tokens-holder-display-block > * { + /* TODO: set this for all modes? */ + font-size: 13px; /* restore standard font size */ +} + +.tokens-holder-display-block .salient-token { + display: inline; + min-height: 1lh; + vertical-align: baseline; +} + +.tokens-holder-display-block.tokens-holder-dense .salient-token span { + /* hack to remove extra whitespace. ugh. */ + margin-right: -0.445ch; +} + +.tokens-holder-display-block .word-spacer { + display: inline; + vertical-align: baseline; + white-space: pre-wrap; +} + +.tokens-holder-display-block lit-tooltip { + --anchor-display-mode: 'inline'; + --tooltip-position-left: 0; } \ No newline at end of file diff --git a/lit_nlp/client/elements/token_chips.ts b/lit_nlp/client/elements/token_chips.ts index bafff398..2e14ffdc 100644 --- a/lit_nlp/client/elements/token_chips.ts +++ b/lit_nlp/client/elements/token_chips.ts @@ -35,9 +35,9 @@ export interface TokenWithWeight { weight: number; selected?: boolean; pinned?: boolean; - onClick?: (e: Event) => void; - onMouseover?: (e: Event) => void; - onMouseout?: (e: Event) => void; + onClick?: (e: MouseEvent) => void; + onMouseover?: (e: MouseEvent) => void; + onMouseout?: (e: MouseEvent) => void; disableHover?: boolean; forceShowTooltip?: boolean; } @@ -50,9 +50,33 @@ export class TokenChips extends LitElement { // List of tokens to display @property({type: Array}) tokensWithWeights: TokenWithWeight[] = []; @property({type: Object}) cmap: SalienceCmap = new UnsignedSalienceCmap(); - @property({type: String}) - tokenGroupTitle?: string; // can be used for gradKey + // Group title, such as the name of the active salience method. + @property({type: String}) tokenGroupTitle?: string; + /** + * Dense mode, for less padding and smaller margins around each chip. + */ @property({type: Boolean}) dense = false; + /** + * Block mode uses display: block and inline elements for chips, instead of + * a flex-row layout. This allows chips to flow across line breaks, behaving + * more like elements and giving a much better experience for larger + * segments like sentences. However, this comes at the cost of more spacing + * artifacts and occasional issues with tooltip positioning. + */ + @property({type: Boolean}) displayBlock = false; + /** + * breakNewlines removes \n at the beginning or end of a segment and inserts + * explicit row break elements instead. Improves readability in many settings, + * at the cost of "faithfulness" to the original token text. + */ + @property({type: Boolean}) breakNewlines = false; + /** + * preSpace removes a leading space from a token and inserts an explicit + * spacer element instead. Improves readability in many settings by giving + * natural space between the highlight area for adjacent words, albeit at the + * cost of hiding where the actual spaces are in the tokenization. 
+ */ + @property({type: Boolean}) preSpace = false; static override get styles() { return [sharedStyles, styles]; @@ -71,17 +95,56 @@ export class TokenChips extends LitElement { 'color': this.cmap.textCmap(tokenInfo.weight), }); - // clang-format off + let tokenText = tokenInfo.token; + + let preSpace = false; + if (this.preSpace && tokenText.startsWith(' ')) { + preSpace = true; + tokenText = tokenText.slice(1); + } + + // TODO(b/324955623): render a gray '⏎' for newlines? + // Maybe make this a toggleable option, as it can be distracting. + // TODO(b/324955623): better rendering for multiple newlines, like \n\n\n ? + // Consider adding an extra ' ' on each line. + + let preBreak = false; + let postBreak = false; + if (this.breakNewlines) { + // Logic: + // - \n : post-break, so blank space goes on previous line + // - foo\n : post-break + // - \nfoo : pre-break + // - \n\n : pre- and post-break, shows a space on its own line + // - \n\n\n : pre- and post-break, two lines with only spaces + if (tokenText.endsWith('\n')) { + // Prefer post-break because this puts the blank space on the end of the + // previous line, rather than creating an awkward indent on the next + // one. + tokenText = tokenText.slice(0, -1) + ' '; + postBreak = true; + } + if (tokenText.startsWith('\n')) { + // Pre-break only if \n precedes some other text. + preBreak = true; + tokenText = ' ' + tokenText.slice(1); + } + } + + // prettier-ignore return html` + ${preBreak ? html`
` : null} + ${preSpace ? html`
` : null}
- ${tokenInfo.token} + ${tokenText} -
`; - // clang-format on +
+ ${postBreak ? html`
` : null} + `; } override render() { @@ -92,9 +155,10 @@ export class TokenChips extends LitElement { const holderClass = classMap({ 'tokens-holder': true, 'tokens-holder-dense': this.dense, + 'tokens-holder-display-block': this.displayBlock, }); - // clang-format off + // prettier-ignore return html`
${this.tokenGroupTitle ? this.tokenGroupTitle : ''} @@ -102,7 +166,6 @@ export class TokenChips extends LitElement { ${tokensDOM}
`; - // clang-format on } } diff --git a/lit_nlp/client/elements/token_chips_test.ts b/lit_nlp/client/elements/token_chips_test.ts index 03c36598..1cdb8a87 100644 --- a/lit_nlp/client/elements/token_chips_test.ts +++ b/lit_nlp/client/elements/token_chips_test.ts @@ -31,7 +31,21 @@ const TESTDATA: Array<{tokensWithWeights: TokenWithWeight[]}> = [ {token: 'hello', weight: 0.7, selected: true, pinned: true}, {token: 'world', weight: 0.3} ], - } + }, + { + // for testing preSpace mode + tokensWithWeights: [ + {token: 'foo', weight: 0.7, selected: true, pinned: true}, + {token: ' bar', weight: 0.3}, {token: 'baz', weight: 0.5} + ], + }, + { + // for testing breakNewlines mode + tokensWithWeights: [ + {token: 'foo', weight: 0.7}, {token: '\nbar', weight: 0.3}, + {token: '\n\n', weight: 0.1}, {token: 'baz\n', weight: 0.5} + ], + }, ]; describe('token chips test', () => { @@ -60,6 +74,58 @@ describe('token chips test', () => { expect(tokenElements[0].children[0]).toBeInstanceOf(LitTooltip); }); + it('should break spaces in preSpace mode', async () => { + tokenChips.preSpace = true; + await tokenChips.updateComplete; + + const tokenElements = + tokenChips.renderRoot.querySelectorAll( + 'div.salient-token'); + expect(tokenElements.length).toEqual(tokensWithWeights.length); + for (let i = 0; i < tokenElements.length; i++) { + const elem = tokenElements[i]; + const expectedToken = tokensWithWeights[i].token; + if (expectedToken.startsWith(' ')) { + // Space moved to a word spacer. + expect(elem.innerText).toEqual(expectedToken.slice(1)); + expect(elem.previousElementSibling?.classList ?? []) + .toContain('word-spacer'); + } else { + // Space intact, no word spacer. + expect(elem.innerText).toEqual(expectedToken); + if (i > 0) { + expect(elem.previousElementSibling?.classList ?? []) + .toContain('salient-token'); + } + } + } + }); + + it('should break newlines in breakNewlines mode', async () => { + tokenChips.breakNewlines = true; + await tokenChips.updateComplete; + + const tokenElements = + tokenChips.renderRoot.querySelectorAll( + 'div.salient-token'); + expect(tokenElements.length).toEqual(tokensWithWeights.length); + for (let i = 0; i < tokenElements.length; i++) { + const elem = tokenElements[i]; + let expectedToken = tokensWithWeights[i].token; + if (expectedToken.endsWith('\n')) { + expectedToken = expectedToken.slice(0, -1) + ' '; + expect(elem.nextElementSibling?.classList ?? []) + .toContain('row-break'); + } + if (expectedToken.startsWith('\n')) { + expectedToken = ' ' + expectedToken.slice(1); + expect(elem.previousElementSibling?.classList ?? []) + .toContain('row-break'); + } + expect(elem.innerText).toEqual(expectedToken); + } + }); + it('should mark a selected token', async () => { const tokenElements = tokenChips.renderRoot.querySelectorAll( diff --git a/lit_nlp/client/lib/types.ts b/lit_nlp/client/lib/types.ts index 5f60fff0..c1bb8794 100644 --- a/lit_nlp/client/lib/types.ts +++ b/lit_nlp/client/lib/types.ts @@ -410,8 +410,10 @@ export declare interface LitCanonicalLayout { */ export declare interface LayoutSettings { hideToolbar?: boolean; - /** The default height of #upper-right, as a percentage of the parent. */ + /** The default height of the 'upper' section, as a percentage. */ mainHeight?: number; + /** The default width of the 'left' section, as a percentage. 
*/ + leftWidth?: number; centerPage?: boolean; } From 80cf699f92cd77d58cb2a2a60b9314010b1f336c Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Tue, 13 Feb 2024 15:54:53 -0800 Subject: [PATCH 22/51] Sequence salience for a decoder-only LM. Demo using GPT-2. PiperOrigin-RevId: 606773805 --- lit_nlp/api/layout.py | 3 +- lit_nlp/client/modules/lm_salience_module.css | 98 +++ lit_nlp/client/modules/lm_salience_module.ts | 833 ++++++++++++++++++ lit_nlp/examples/datasets/lm.py | 51 +- .../examples/datasets/prompt_examples.jsonl | 9 + lit_nlp/examples/lm_salience_demo.py | 177 ++++ lit_nlp/examples/models/pretrained_lms.py | 334 ++++++- .../models/pretrained_lms_int_test.py | 6 +- 8 files changed, 1465 insertions(+), 46 deletions(-) create mode 100644 lit_nlp/client/modules/lm_salience_module.css create mode 100644 lit_nlp/client/modules/lm_salience_module.ts create mode 100644 lit_nlp/examples/datasets/prompt_examples.jsonl create mode 100644 lit_nlp/examples/lm_salience_demo.py diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py index 637b3f19..ce0ab4c1 100644 --- a/lit_nlp/api/layout.py +++ b/lit_nlp/api/layout.py @@ -21,7 +21,6 @@ from lit_nlp.api import dtypes -# LINT.IfChange # pylint: disable=invalid-name @enum.unique class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum): @@ -48,6 +47,7 @@ class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum): GeneratedTextModule = 'generated-text-module' GeneratorModule = 'generator-module' LanguageModelPredictionModule = 'lm-prediction-module' + LMSalienceModule = 'lm-salience-module' MetricsModule = 'metrics-module' MultilabelModule = 'multilabel-module' PdpModule = 'pdp-module' @@ -68,6 +68,7 @@ def __call__(self, **kw): return ModuleConfig(self.value, **kw) +# LINT.IfChange # TODO(lit-dev): consider making modules subclass this instead of LitModuleName. @attr.s(auto_attribs=True) class ModuleConfig(dtypes.DataTuple): diff --git a/lit_nlp/client/modules/lm_salience_module.css b/lit_nlp/client/modules/lm_salience_module.css new file mode 100644 index 00000000..22192870 --- /dev/null +++ b/lit_nlp/client/modules/lm_salience_module.css @@ -0,0 +1,98 @@ +.flex-column { + display: flex; + flex-direction: column; +} + +.chip-container { + padding: 8px; +} + +.chip-container-dense { + padding: 8px; +} + +.pre-wrap { + white-space: pre-wrap; +} + +.gray-text { + color: var(--lit-neutral-400); +} + +.target-info-line { + white-space: nowrap; + text-overflow: ellipsis; + overflow-x: hidden; +} + +lit-switch .icon-button { + vertical-align: middle; +} + +/** + * Module controls + */ +.module-toolbar { + border-bottom: 1px solid #dadce0; + box-sizing: border-box; + justify-content: space-between; +} + +.module-footer { + justify-content: space-between; +} + +.controls-group { + display: flex; + flex-direction: row; + align-items: center; + margin: 0 4px; + gap: 4px; +} + +.controls-group[disabled] { + color: rgb(60, 64, 67); + opacity: 0.38; + pointer-events: none; +} + +/* Allow contents to consume available space, but not to cause line wrapping. */ +.controls-group-variable { + flex: 1; + overflow-x: clip; + margin-right: 8px; +} + +.controls-group-variable > label { + min-width: 45px; +} + +.controls-group-variable .dropdown { + max-width: calc(100% - 45px); +} + +.vertical-separator { + background: #dadce0; + width: 2px; + height: 1.2rem; + padding: 0; + margin: 0 8px; +} + +/* Allow wrap. TODO move this to shared_styles as an option? 
*/ +.module-footer-wrappable { + flex-wrap: wrap; + /* line-height: 30px; */ /* causes alignment issues */ + height: unset; + min-height: 36px; +} + +.module-footer > * { min-width: 0; } + +.controls-group > * { min-width: 0; } + +color-legend { + /* extra space to keep other controls from jumping when legend changes */ + /* width: 400px; */ + margin-right: 16px; +} \ No newline at end of file diff --git a/lit_nlp/client/modules/lm_salience_module.ts b/lit_nlp/client/modules/lm_salience_module.ts new file mode 100644 index 00000000..b2deb9c7 --- /dev/null +++ b/lit_nlp/client/modules/lm_salience_module.ts @@ -0,0 +1,833 @@ +/** + * @fileoverview Custom viz module for causal LM salience. + */ + +import '@material/mwc-icon'; +import '../elements/color_legend'; +import '../elements/numeric_input'; +import '../elements/fused_button_bar'; + +import {css, html} from 'lit'; +// tslint:disable:no-new-decorators +import {customElement} from 'lit/decorators.js'; +import {computed, observable, toJS} from 'mobx'; + +import {LitModule} from '../core/lit_module'; +import {LegendType} from '../elements/color_legend'; +import {NumericInput as LitNumericInput} from '../elements/numeric_input'; +import {TokenChips, TokenWithWeight} from '../elements/token_chips'; +import {SalienceCmap, SignedSalienceCmap, UnsignedSalienceCmap,} from '../lib/colors'; +import {GENERATION_TYPES, getAllTargetOptions, TargetOption, TargetSource} from '../lib/generated_text_utils'; +import {LitType, LitTypeTypesList, Tokens, TokenScores} from '../lib/lit_types'; +import {styles as sharedStyles} from '../lib/shared_styles.css'; +import {cleanSpmText, groupTokensByRegexPrefix} from '../lib/token_utils'; +import {type IndexedInput, type Preds, SCROLL_SYNC_CSS_CLASS, type Spec} from '../lib/types'; +import {cumSumArray, filterToKeys, findSpecKeys, groupAlike, makeModifiedInput, sumArray} from '../lib/utils'; + +import {styles} from './lm_salience_module.css'; + +/** + * Max of absolute value + */ +export function maxAbs(vals: number[]): number { + return Math.max(...vals.map(Math.abs)); +} + +enum SegmentationMode { + TOKENS = 'Tokens', + WORDS = 'Words', + SENTENCES = 'Sentences', + LINES = 'Lines', + // TODO(b/324961811): add phrase or clause chunking? + // TODO(b/324961803): add custom regex? +} + +const LEGEND_INFO_TITLE_SIGNED = + 'Salience is relative to the model\'s prediction of a token. A positive ' + + 'score (more green) for a token means that token influenced the model to ' + + 'predict the selected target, whereas a negaitve score (more pink) means ' + + 'the token influenced the model to not predict the selected target.'; + +const LEGEND_INFO_TITLE_UNSIGNED = + 'Salience is relative to the model\'s prediction of a token. A larger ' + + 'score (more purple) for a token means that token was more influential ' + + 'on the model\'s prediction of the selected target.'; + +/** + * A convenience implementation of LitModule for single model, single example + * use. Implements some standard boilerplate to fetch model predictions. + * + * Subclass should still register this with @customElement, and add to the + * HTMLElementTagNameMap, we well as implement: + * - static template = ... + * - override renderImpl() {...} + * + * And optionally: + * - static styles() {...} + * - static override shouldDisplayModule() {...} + * + * If subclass implements firstUpdated(), be sure to call super.firstUpdated() + * to register the reaction to the primary selection. 
+ */ +export class SingleExampleSingleModelModule extends LitModule { + static override duplicateForExampleComparison = true; + static override duplicateForModelComparison = true; + + // Override this to request only specific types. + protected predsTypes: LitTypeTypesList = [LitType]; + + @observable protected currentData?: IndexedInput; + @observable protected currentPreds?: Preds; + + // Override this for any post-processing. + protected postprocessPreds(input: IndexedInput, preds: Preds): Preds { + return preds; + } + + protected resetState() { + this.currentData = undefined; + this.currentPreds = undefined; + } + + protected async updateToSelection(input: IndexedInput|null) { + this.resetState(); + + if (input == null) return; + + // Before waiting for the backend call, update data. + // currentPreds should already be cleared by the resetState() call above. + this.currentData = input; + + const promise = this.apiService.getPreds( + [input], + this.model, + this.appState.currentDataset, + this.predsTypes, + [], + 'Getting model predictions.', + ); + const results = await this.loadLatest('modelPreds', promise); + if (results === null) return; + + const preds = this.postprocessPreds(input, results[0]); + + // Update data again, in case selection changed rapidly. + this.currentData = input; + this.currentPreds = preds; + } + + override firstUpdated() { + this.reactImmediately( + () => this.selectionService.primarySelectedInputData, + (data) => { + this.updateToSelection(data); + }, + ); + } +} + +/** + * Custom styled version of for rendering LM salience tokens. + */ +@customElement('lm-salience-chips') +class LMSalienceChips extends TokenChips { + static override get styles() { + return [ + ...TokenChips.styles, + css` + .salient-token { + padding: 1px 3px; /* wider horizontally */ + margin: 2px; + min-width: 4px; /* easier to see whitespace tokens */ + } + .tokens-holder:not(.tokens-holder-dense) .salient-token:not(.selected) { + --token-outline-color: var(--lit-neutral-300); /* outline in non-dense mode */ + } + .tokens-holder-display-block .salient-token { + padding: 3px 0; + margin: 0; + margin-right: 4px; + } + .salient-token.selected { + --token-outline-color: var(--lit-mage-700); + box-shadow: 0px 0px 3px var(--token-outline-color); + } + .tokens-holder-dense .salient-token { + margin: 2px 0px; /* vertical spacing only */ + min-width: 6px; /* not too small. Check if this causes issues inside words. */ + } + .tokens-holder-dense .salient-token.selected { + outline: 2px solid var(--token-outline-color); + border: 0; + box-shadow: unset; + /* TODO see if we can get away from z-index here */ + z-index: 10; + } + `, + ]; + } +} + +interface SalienceResults { + [method: string]: number[]; +} + +// Sentinel value because mobx doesn't react directly to a promise completing. +const REQUEST_PENDING: unique symbol = Symbol('REQUEST_PENDING'); + +/** LIT module for model output. */ +@customElement('lm-salience-module') +export class LMSalienceModule extends SingleExampleSingleModelModule { + static override title = 'LM Salience'; + static override numCols = 6; // 60% of screen width if DataTable on left + static override duplicateAsRow = true; + // prettier-ignore + static override template = ( + model: string, + selectionServiceIndex: number, + shouldReact: number, + ) => html` + `; + + static override get styles() { + return [sharedStyles, styles]; + } + + // For generation model. For salience, see updateSalience() below. 
+ override predsTypes = GENERATION_TYPES; + + @observable + private segmentationMode: SegmentationMode = SegmentationMode.WORDS; + // TODO(b/324959547): get default from spec + @observable private selectedSalienceMethod? = 'grad_l2'; + @observable private cmapGamma = 1.0; + @observable private denseView = true; + @observable private showSelfSalience = false; + + @observable.ref private currentTokens: string[] = []; + @observable.ref private salienceTargetOptions: TargetOption[] = []; + @observable private salienceTargetString = ''; + @observable.ref private targetSegmentSpan?: [number, number] = undefined; + + + /** + * Cache for salience results for different target spans. + * Because computing salience can be slow and we don't want to lock the + * frontend, we use this cache as an intermediary between the API calls + * (updateSalience) and the rendering logic. API calls are asynchronous with + * updates and populate this cache with their results; the rendering logic + * then observes this cache and renders only the result with the current + * selected span. + * + * Each cache entry can have three states: + * - undefined: we haven't seen it yet, so updateSalience will issue a backend + * call. + * - REQUEST_PENDING: sentinel value, set while a backend call is in progress. + * - Otherwise, will contain a SalienceResults object with results for that + * key. + */ + @observable + private salienceResultCache: + {[targetKey: string]: SalienceResults|(typeof REQUEST_PENDING)} = {}; + + @computed + get salienceModelName(): string { + return `_${this.model}_salience`; + } + + @computed + get tokenizerModelName(): string { + // TODO: fall back to salience model if not available? + return `_${this.model}_tokenizer`; + } + + private resetTargetSpan() { + this.targetSegmentSpan = undefined; + } + + override resetState() { + // Generation & target string selection + super.resetState(); // currentData and currentPreds + this.salienceTargetOptions = []; + this.salienceTargetString = ''; + // Tokens and selected target span + this.currentTokens = []; + this.resetTargetSpan(); + // Salience results + this.salienceResultCache = {}; + } + + // Get generations; populate this.currentPreds + protected override async updateToSelection(input: IndexedInput|null) { + await super.updateToSelection(input); + this.resetTargetSpan(); + + const dataSpec = this.appState.currentDatasetSpec; + const outputSpec = this.appState.getModelSpec(this.model).output; + this.salienceTargetOptions = getAllTargetOptions( + dataSpec, + outputSpec, + this.currentData, + this.currentPreds, + ); + this.salienceTargetString = this.salienceTargetOptions[0]?.text ?? ''; + } + + // Modified input with selected target sequence. Use this for tokens and + // salience. 
+ @computed + get modifiedData(): IndexedInput|null { + if (this.currentData == null) return null; + return makeModifiedInput( + this.currentData, {'target': this.salienceTargetString}); + } + + @computed + get currentTokenGroups(): string[][] { + if (this.segmentationMode === SegmentationMode.TOKENS) { + return this.currentTokens.map(t => [t]); + } else if (this.segmentationMode === SegmentationMode.WORDS) { + // Word start is either: + // - whitespace or magic underscore + // - any non-\n following \n + // The latter is needed to avoid forming weird segments like '\n\nfoo'; + // by using the lookbehind, this will end up as ['\n\n', 'foo'] + return groupTokensByRegexPrefix( + this.currentTokens, /([▁\s]+)|(?<=\n)[^\n]/g); + } else if (this.segmentationMode === SegmentationMode.SENTENCES) { + // Sentence start is one of: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + // - whitespace or magic underscore following punctuation [.?!] + return groupTokensByRegexPrefix( + this.currentTokens, /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([▁\s]+))/g); + } else if (this.segmentationMode === SegmentationMode.LINES) { + // Line start is either: + // - a run of consecutive \n as its own segment + // - any non-\n following \n + return groupTokensByRegexPrefix(this.currentTokens, /(\n+)|([^\n]+)/g); + } else { + throw new Error( + `Unsupported segmentation mode ${this.segmentationMode}.`); + } + } + + /** + * Segment offsets, as token indices. + * Segment i corresponds to tokens offsets[i]:offsets[i+1] + */ + @computed + get currentSegmentOffsets(): number[] { + return [0, ...cumSumArray(this.currentTokenGroups.map(g => g.length))]; + } + + @computed + get targetTokenSpan(): number[]|undefined { + if (this.targetSegmentSpan === undefined) return undefined; + const [segmentStart, segmentEnd] = this.targetSegmentSpan; + const offsets = this.currentSegmentOffsets; + return [offsets[segmentStart], offsets[segmentEnd]]; + } + + @computed + get currentSegmentTexts(): string[] { + const segments = this.currentTokenGroups.map(tokens => tokens.join('')); + // Tokens in non-dense view should show exact tokenization, including magic + // underscores. + if (this.segmentationMode === SegmentationMode.TOKENS && !this.denseView) { + return segments; + } + // Otherwise, clean up underscores. + return segments.map(cleanSpmText); + } + + @computed + get salienceSpecInfo(): Spec { + const outputSpec = + this.appState.getModelSpec(this.salienceModelName).output; + const salienceKeys = findSpecKeys(outputSpec, TokenScores); + return filterToKeys(outputSpec, salienceKeys); + } + + /** + * Salience for active model, for all tokens. + */ + @computed + get activeTokenSalience(): number[]|undefined { + if (this.targetTokenSpan === undefined) return undefined; + + const cachedResult = + this.salienceResultCache[this.spanToKey(this.targetTokenSpan)]; + if (cachedResult === undefined || cachedResult === REQUEST_PENDING) { + return undefined; + } + + if (this.selectedSalienceMethod === undefined) { + return undefined; + } + + return cachedResult[this.selectedSalienceMethod]; + } + + /** + * Salience for active mode, for current segments. 
+ */ + @computed + get activeSalience(): number[]|undefined { + if (this.activeTokenSalience === undefined) return undefined; + const groupedSalience = + groupAlike(this.activeTokenSalience, this.currentTokenGroups); + return groupedSalience.map(sumArray); + } + + @computed + get cmapRange(): number { + if (this.activeSalience === undefined) return 1; + // If nothing focused, use the max over all (absolute) scores. + return Math.max(1e-3, maxAbs(this.activeSalience)); + } + + @computed + get signedSalienceCmap() { + return new SignedSalienceCmap(this.cmapGamma, [ + -1 * this.cmapRange, + this.cmapRange, + ]); + } + + @computed + get unsignedSalienceCmap() { + return new UnsignedSalienceCmap(this.cmapGamma, [0, this.cmapRange]); + } + + @computed + get cmap(): SalienceCmap { + // TODO(b/324959547): get signed/unsigned info from spec. + // May need to add a signed= bit to the TokenScores type, + // or use the TokenSalience type. + return this.selectedSalienceMethod === 'grad_dot_input' ? + this.signedSalienceCmap : + this.unsignedSalienceCmap; + } + + spanToKey(span: number[]) { + return `${span[0]}:${span[1]}`; + } + + async updateTokens(input: IndexedInput|null) { + if (input == null) { + this.currentTokens = []; + return; + } + + const promise = this.apiService.getPreds( + [input], + this.tokenizerModelName, + this.appState.currentDataset, + [Tokens], + [], + `Fetching tokens`, + ); + const results = await promise; + if (results === null) { + console.warn('No tokens returned for request', input); + return; + } + + // TODO(b/324959547): get field name from spec, rather than hardcoding + // 'tokens'. + const tokens: string[] = results[0]['tokens']; + if (this.modifiedData === input) { + this.currentTokens = tokens; + } else { + console.warn( + 'Stale request; discarding result. Request does not match current target.', + input, toJS(this.modifiedData)); + } + } + + async updateSalience(targetTokenSpan: number[]|undefined) { + if (this.modifiedData == null) return; + if (targetTokenSpan === undefined) return; + + const spanKey = this.spanToKey(targetTokenSpan); + const cachedResult = this.salienceResultCache[spanKey]; + if (cachedResult !== undefined) { + if (cachedResult === REQUEST_PENDING) { + // Another call is waiting and we can let that update the results. + console.log('Duplicate request for target span ', spanKey); + } else { + // Actual results. + console.log('Found cached return for target span ', spanKey); + } + // No need to proceed with backend call in either case. + return; + } + + this.salienceResultCache[spanKey] = REQUEST_PENDING; + + const [start, end] = targetTokenSpan; + const targetMask = this.currentTokens.map( + (t: string, i) => (i >= start && i < end) ? 1 : 0); + + // TODO(b/324959547): don't hard-code 'target_mask', get field name from + // spec. We may want to create a custom TargetMask type for this. 
+ const maskedData = makeModifiedInput( + this.modifiedData, {'target_mask': targetMask}, 'salience'); + + const promise = this.apiService.getPreds( + [maskedData], + this.salienceModelName, + this.appState.currentDataset, + [TokenScores], + [], + `Getting salience scores for ${this.printTargetForHuman(start, end)}`, + ); + const results = await promise; + if (results === null) { + console.warn('Empty results from request', maskedData, spanKey); + delete this.salienceResultCache[spanKey]; + return; + } + + this.salienceResultCache[spanKey] = results[0]; + } + + override firstUpdated() { + super.firstUpdated(); + + // If selected example OR selected target string change. + // NOTE: you may see a console warning: "Element lm-salience-module + // scheduled an update (generally because a property was set) after an + // update completed, causing a new update to be scheduled." + // This is okay here: this.modifiedData will be updated after + // updateToSelection() runs, which will trigger this to update tokens. + this.reactImmediately(() => this.modifiedData, (data) => { + this.resetTargetSpan(); + this.updateTokens(data); + }); + + this.reactImmediately(() => this.targetTokenSpan, (targetTokenSpan) => { + this.updateSalience(targetTokenSpan); + }); + } + + renderGranularitySelector() { + const onClickToggleDensity = () => { + this.denseView = !this.denseView; + }; + + const segmentationOptions = Object.values(SegmentationMode).map((val) => { + return { + text: val, + selected: this.segmentationMode === val, + onClick: () => { + if (this.segmentationMode !== val) { + this.targetSegmentSpan = undefined; + } + this.segmentationMode = val as SegmentationMode; + }, + }; + }); + + // prettier-ignore + return html` +
+ + + + + + notes + + + grid_view + + +
+ `; + } + + renderSelfScoreSelector() { + const onClickToggleSelfSalience = () => { + this.showSelfSalience = !this.showSelfSalience; + }; + // prettier-ignore + return html` + + + `; + } + + renderMethodSelector() { + const methodOptions = Object.keys(this.salienceSpecInfo).map((key) => { + return { + text: key, + selected: this.selectedSalienceMethod === key, + onClick: () => { + if (this.selectedSalienceMethod !== key) { + this.selectedSalienceMethod = key; + } + }, + }; + }); + + // prettier-ignore + return html` +
+ + + + ${this.renderSelfScoreSelector()} +
+ `; + } + + targetSpanText(start: number, end: number): string { + const tokens = this.currentTokens.slice(start, end); + // Render text in a way that resembles the way the token chips read + // at the current display density. Text should match currentSegmentTexts, + // except: + // - Tokens are joined with spaces in non-dense Tokens mode + // - Whitespace is trimmed in all other modes + if (this.segmentationMode === SegmentationMode.TOKENS && !this.denseView) { + return tokens.join(' '); + } + return cleanSpmText(tokens.join('')).trim(); + } + + printTargetForHuman(start: number, end: number): string { + if (end === start + 1) { + return `[${start}] "${this.targetSpanText(start, end)}"`; + } else { + return `[${start}:${end}] "${this.targetSpanText(start, end)}"`; + } + } + + renderSalienceTargetStringSelector() { + const onChangeTarget = (e: Event) => { + this.salienceTargetString = (e.target as HTMLInputElement).value; + }; + + const options = this.salienceTargetOptions.map(target => { + // TODO(b/324959547): get field names 'target' and 'response' from spec + // via generated_text_utils.ts, rather than hard-coding. + // This information is available on the frontend, but we need to thread + // it through a few layers of code in generated_text_utils.ts + const sourceName = + target.source === TargetSource.REFERENCE ? 'target' : 'response'; + return html``; + }); + + // prettier-ignore + return html` +
+ + +
`; + } + + renderTargetIndicator() { + const printSelectedTargets = () => { + if (this.targetTokenSpan === undefined) { + const segmentType = this.segmentationMode === SegmentationMode.TOKENS ? + 'token(s)' : + 'segment(s)'; + // prettier-ignore + return html` + Click ${segmentType} above to select a target span. + `; + } + const [start, end] = this.targetTokenSpan; + return `Explaining ${this.printTargetForHuman(start, end)}`; + }; + + // prettier-ignore + return html` +
+
+ ${printSelectedTargets()} +
+
+ `; + } + + /** + * Set selection (this.targetSegmentSpan) based on current selection and the + * index of the clicked segment (i). + */ + private setSegmentTarget(i: number, shiftSelect = false) { + if (this.targetSegmentSpan === undefined) { + // If nothing selected, select token i + this.targetSegmentSpan = [i, i + 1]; + return; + } + const [start, end] = this.targetSegmentSpan; + if (shiftSelect) { + // Shift: expand target span to this token. + if (i < start) { + this.targetSegmentSpan = [i, end]; + } else if (i >= end) { + this.targetSegmentSpan = [start, i + 1]; + } + // Otherwise, i is within selection so do nothing. + } else { + // Default: only extend by one, otherwise reset. + if (i === start - 1) { + // Extend by one token earlier. + this.targetSegmentSpan = [i, end]; + } else if (i === end) { + // Extend by one token later. + this.targetSegmentSpan = [start, i + 1]; + } else if (i === start) { + // Deselect start token. + this.targetSegmentSpan = start + 1 < end ? [start + 1, end] : undefined; + } else if (i === end - 1) { + // Deselect end token. + this.targetSegmentSpan = start < end - 1 ? [start, end - 1] : undefined; + } else { + // // Interior or discontiguous: select only token i. + this.targetSegmentSpan = [i, i + 1]; + } + } + } + + private inTargetSpan(i: number) { + if (this.targetSegmentSpan === undefined) return false; + return i >= this.targetSegmentSpan[0] && i < this.targetSegmentSpan[1]; + } + + renderContent() { + if (this.currentSegmentTexts.length === 0) return null; + + const segments: string[] = this.currentSegmentTexts; + const segmentsWithWeights: TokenWithWeight[] = []; + for (let i = 0; i < segments.length; i++) { + const selected = this.inTargetSpan(i); + let weight = this.activeSalience?.[i] ?? 0; + if (selected && !this.showSelfSalience) { + weight = 0; + } + segmentsWithWeights.push({ + token: segments[i], + weight, + selected, + onClick: (e: MouseEvent) => { + this.setSegmentTarget(i, e.shiftKey); + if (e.shiftKey) { + // Holding shift will also select the token text, which can be + // distracting. Use this to clear it. + document.getSelection()?.removeAllRanges(); + } + e.stopPropagation(); + } + }); + } + + // TODO: revert to 4px for non-dense view if we can figure out the + // display mode for token chips? Needs more padding for block mode, + // but also indentation and newlines are wonky. + // prettier-ignore + return html` +
+ + +
+ `; + } + + renderColorLegend() { + const cmap = this.cmap; + const isSigned = cmap instanceof SignedSalienceCmap; + const labelName = 'Salience'; + + const tooltipText = + isSigned ? LEGEND_INFO_TITLE_SIGNED : LEGEND_INFO_TITLE_UNSIGNED; + + // prettier-ignore + return html` + + `; + } + + renderColorControls() { + const onChangeGamma = (e: Event) => { + // Note: HTMLInputElement.valueAsNumber does not work properly for + // + this.cmapGamma = Number((e.target as LitNumericInput).value); + }; + + const resetGamma = () => { + this.cmapGamma = 1.0; + }; + + // prettier-ignore + return html` +
+ ${this.renderColorLegend()} + + + + + restart_alt + +
`; + } + + override renderImpl() { + const clearTargets = () => { + this.resetTargetSpan(); + }; + + // prettier-ignore + return html` +
+
+ ${this.renderSalienceTargetStringSelector()} +
+
+ ${this.renderGranularitySelector()} + ${this.renderMethodSelector()} +
+
+ ${this.renderContent()} +
+ +
+ `; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'lm-salience-chips': LMSalienceChips; + 'lm-salience-module': LMSalienceModule; + } +} \ No newline at end of file diff --git a/lit_nlp/examples/datasets/lm.py b/lit_nlp/examples/datasets/lm.py index bd987b84..d2292f44 100644 --- a/lit_nlp/examples/datasets/lm.py +++ b/lit_nlp/examples/datasets/lm.py @@ -1,12 +1,18 @@ """Language modeling datasets.""" +import copy +import json +import os import glob from typing import Optional +from absl import logging from lit_nlp.api import dataset as lit_dataset from lit_nlp.api import types as lit_types import tensorflow_datasets as tfds +SAMPLE_DATA_DIR = os.path.dirname(__file__) + class PlaintextSents(lit_dataset.Dataset): """Load sentences from a flat text file.""" @@ -16,7 +22,9 @@ def __init__( path_or_glob: str, skiplines: int = 0, max_examples: Optional[int] = None, + field_name: str = 'text', ): + self.field_name = field_name self._examples = self.load_datapoints(path_or_glob, skiplines=skiplines)[ :max_examples ] @@ -44,7 +52,7 @@ def load_datapoints(self, path_or_glob: str, skiplines: int = 0): continue line = line.strip() if line: # skip blank lines, these are usually document breaks - examples.append({'text': line}) + examples.append({self.field_name: line}) return examples def load(self, path: str): @@ -52,7 +60,46 @@ def load(self, path: str): def spec(self) -> lit_types.Spec: """Should match MLM's input_spec().""" - return {'text': lit_types.TextSegment()} + return {self.field_name: lit_types.TextSegment()} + + +class PromptExamples(lit_dataset.Dataset): + """Prompt examples for modern LMs.""" + + SAMPLE_DATA_PATH = os.path.join(SAMPLE_DATA_DIR, 'prompt_examples.jsonl') + + def load_datapoints(self, path: str): + if not path: + logging.warn( + 'Empty path to PromptExamples.load_datapoints(). Returning empty' + ' dataset.' 
+ ) + return [] + + default_ex_values = { + k: copy.deepcopy(field_spec.default) + for k, field_spec in self.spec().items() + } + + examples = [] + with open(path) as fd: + for line in fd: + examples.append(default_ex_values | json.loads(line)) + + return examples + + def __init__(self, path: str): + self._examples = self.load_datapoints(path) + + def spec(self) -> lit_types.Spec: + return { + 'source': lit_types.CategoryLabel(), + 'prompt': lit_types.TextSegment(), + 'target': lit_types.TextSegment(), + } + + def load(self, path: str): + return lit_dataset.Dataset(base=self, examples=self.load_datapoints(path)) class BillionWordBenchmark(lit_dataset.Dataset): diff --git a/lit_nlp/examples/datasets/prompt_examples.jsonl b/lit_nlp/examples/datasets/prompt_examples.jsonl new file mode 100644 index 00000000..c5dd7cda --- /dev/null +++ b/lit_nlp/examples/datasets/prompt_examples.jsonl @@ -0,0 +1,9 @@ +{"source": "gigaword-summarization", "prompt": "Summarize this.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nSummary:\n", "target": "- Transocean and globalsantafe merge to form a large offshore drilling company.\n- The combined company will offer a full range of services in the world's key markets."} +{"source": "gigaword-summarization", "prompt": "Summarize the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nSummary:\n", "target": "* Transocean and Globalsantafe merge to form a new, larger company.\n* The combined company will offer a full range of offshore drilling services.\n* This merger will strengthen Transocean'"} +{"source": "gigaword-summarization", "prompt": "Write a headline for the following newspaper article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "Transocean and Globalsantafe merge: New giant in offshore drilling"} +{"source": "gigaword-summarization", "prompt": "You are an editor at the New York Times. Write a headline for the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "**Transocean and Globalsantafe Merge in a Giant Move for Offshore Drilling**"} +{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?", "target": "\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? 
** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? ** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? **", "target": " Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. Cutting fabric for crafts.\n2. Cutting herbs and vegetables in the kitchen.\n3. Cutting paper for DIY projects.\n\nPlease provide detailed instructions for using any two of the ideas."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item. The list should include creative uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic embellishments:** Use scissors to create intricate designs and patterns on fabric, paper, or other materials.\n2. **Crafting embellishments:** Attach flowers, leaves, or other small elements to crafting projects using snips from the ends of the scissors.\n3. **Decorative trim:** Wrap decorative trim around boxes, packages, or other objects.\n\nBonus Idea:\n\n4. **Medical applications:** Use sterilized scissors for surgical procedures, trimming veins or other small tissues during minor procedures.\n\nExplain your reasoning and provide examples for each idea."} +{"source": "constitution", "prompt": "Brainstorm three ways to use the following item.The list should include creative uses for the item. The list should include at least three specific uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic tool:** Use scissors to create intricate patterns and designs on fabric, paper, or wood.\n2. **Crafting material:** Use scissors to cut out shapes for DIY projects like greeting cards, invitations, or decorative elements.\n3. **Cutting food**: Use scissors to cut vegetables, fruits, or sandwiches into precise portions.\n\n**Please provide the three specific uses for the scissors. 
The more specific and unique, the better.**"} diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py new file mode 100644 index 00000000..f9637940 --- /dev/null +++ b/lit_nlp/examples/lm_salience_demo.py @@ -0,0 +1,177 @@ +"""Demo for sequence salience with a left-to-right language model.""" + +from collections.abc import Sequence +import functools +import sys +from typing import Optional + +from absl import app +from absl import flags +from absl import logging +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.api import layout +from lit_nlp.examples.datasets import lm as lm_data +from lit_nlp.examples.models import pretrained_lms + +# NOTE: additional flags defined in server_flags.py + +FLAGS = flags.FLAGS + +FLAGS.set_default("development_demo", True) + +_MODELS = flags.DEFINE_list( + "models", + [ + "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz" + ], + "Models to load, as :. Currently supports GPT-2 variants.", +) + +_MAX_EXAMPLES = flags.DEFINE_integer( + "max_examples", + 1000, + ( + "Maximum number of examples to load from each evaluation set. Set to" + " None to load the full set." + ), +) + +# Custom frontend layout; see api/layout.py +modules = layout.LitModuleName +LM_LAYOUT = layout.LitCanonicalLayout( + left={ + "Data Table": [modules.DataTableModule], + "Embeddings": [modules.EmbeddingsModule], + }, + upper={ + "Datapoint Editor": [modules.DatapointEditorModule], + "Datapoint Generators": [modules.GeneratorModule], + }, + lower={ + "Salience": [modules.LMSalienceModule], + "Metrics": [modules.MetricsModule], + }, + layoutSettings=layout.LayoutSettings( + mainHeight=40, + leftWidth=40, + ), + description="Custom layout for language model salience.", +) +SIMPLE_LM_LAYOUT = layout.LitCanonicalLayout( + upper={ + "Examples": [modules.SimpleDataTableModule], + "Editor": [modules.SimpleDatapointEditorModule], + }, + lower={ + "Salience": [modules.LMSalienceModule], + }, + layoutSettings=layout.LayoutSettings( + hideToolbar=True, + mainHeight=40, + centerPage=True, + ), + description="Simplified layout for language model salience.", +) + +CUSTOM_LAYOUTS = { + "simple": SIMPLE_LM_LAYOUT, + "three_panel": LM_LAYOUT, +} + +FLAGS.set_default("page_title", "LM Salience Demo") +FLAGS.set_default("default_layout", "simple") + +_SPLASH_SCREEN_DOC = """ +# Language Model Salience + +To begin, select an example, then click the segment(s) (tokens, words, etc.) +of the output that you would like to explain. Preceding segments(s) will be +highlighted according to their importance to the selected target segment(s), +with darker colors indicating a greater influence (salience) of that segment on +the model's likelihood of the target segment. +""" + + +def get_wsgi_app() -> Optional[dev_server.LitServerType]: + """Return WSGI app for container-hosted demos.""" + FLAGS.set_default("server_type", "external") + FLAGS.set_default("demo_mode", True) + # Parse flags without calling app.run(main), to avoid conflict with + # gunicorn command line flags. 
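For context, `get_wsgi_app()` is the entry point a WSGI container imports instead of running `main()`; a minimal sketch of such an entrypoint (the file name and server choice are assumptions, not part of this patch):

```python
# wsgi.py: hypothetical container entrypoint. The WSGI host (e.g. gunicorn)
# would reference this module:variable rather than invoking main() directly.
from lit_nlp.examples import lm_salience_demo

app = lm_salience_demo.get_wsgi_app()
```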
+ unused = flags.FLAGS(sys.argv, known_only=True) + if unused: + logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused) + return main([]) + + +def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + plaintextPrompts = functools.partial( # pylint: disable=invalid-name + lm_data.PlaintextSents, field_name="prompt" + ) + # Hack: normally dataset loaders are a class object which has a __name__, + # rather than a functools.partial + plaintextPrompts.__name__ = "PlaintextSents" + + # Pre-loaded datasets. + datasets = { + "sample_prompts": lm_data.PromptExamples( + lm_data.PromptExamples.SAMPLE_DATA_PATH + ), + } + + # For loading from the UI. + dataset_loaders = { + "jsonl_examples": ( + lm_data.PromptExamples, + lm_data.PromptExamples.init_spec(), + ), + "plaintext_inputs": ( + plaintextPrompts, + lm_data.PlaintextSents.init_spec(), + ), + } + + ## + # Load models, according to the --models flag. + models = {} + for model_string in _MODELS.value: + # Only split on the first ':', because path may be a URL + # containing 'https://' + model_name, path = model_string.split(":", 1) + logging.info("Loading model '%s' from '%s'", model_name, path) + if model_name.startswith("gpt2") or model_name in ["distilgpt2"]: + models[model_name] = pretrained_lms.GPT2GenerativeModel(path) + # Salience wrapper, using same underlying Keras models so as not to + # load the weights twice. + models[f"_{model_name}_salience"] = ( + pretrained_lms.GPT2SalienceModel.from_loaded(models[model_name]) + ) + models[f"_{model_name}_tokenizer"] = ( + pretrained_lms.GPT2TokenizerModel.from_loaded(models[model_name]) + ) + else: + raise ValueError( + f"Unsupported model name '{model_name}' from path '{path}'" + ) + + for name in datasets: + datasets[name] = datasets[name].slice[: _MAX_EXAMPLES.value] + logging.info("Dataset: '%s' with %d examples", name, len(datasets[name])) + + lit_demo = dev_server.Server( + models, + datasets, + layouts=CUSTOM_LAYOUTS, + dataset_loaders=dataset_loaders, + onboard_start_doc=_SPLASH_SCREEN_DOC, + **server_flags.get_flags(), + ) + return lit_demo.serve() + + +if __name__ == "__main__": + app.run(main) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 8b72f8fe..28baea46 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -6,6 +6,8 @@ functions to predict a batch of examples and extract information such as hidden states and attention. """ +from collections.abc import Sequence +import functools import re from lit_nlp.api import model as lit_model @@ -147,6 +149,7 @@ def output_spec(self): } +# TODO(lit-dev): merge with below, inherit from GPT2BaseModel. class GPT2LanguageModel(lit_model.BatchedModel): """Wrapper for a Huggingface Transformers GPT-2 model. @@ -203,7 +206,7 @@ def clean_bpe_token(tok): else: return tok.replace("Ġ", "") - def _detokenize(self, ids): + def ids_to_clean_tokens(self, ids): tokens = self.tokenizer.convert_ids_to_tokens(ids) return [self.clean_bpe_token(t) for t in tokens] @@ -255,7 +258,7 @@ def _postprocess(self, preds): """Post-process single-example preds. Operates on numpy arrays.""" ntok = preds.pop("ntok") ids = preds.pop("input_ids")[:ntok] - preds["tokens"] = self._detokenize(ids) + preds["tokens"] = self.ids_to_clean_tokens(ids) # Decode predicted top-k tokens. 
# token_topk_preds will be a list[list[(word, prob)]] @@ -264,7 +267,7 @@ def _postprocess(self, preds): pred_ids = preds.pop("top_k_indices")[:ntok] # [num_tokens, k] pred_probs = preds.pop("top_k_probs")[:ntok] # [num_tokens, k] for token_pred_ids, token_pred_probs in zip(pred_ids, pred_probs): - token_pred_words = self._detokenize(token_pred_ids) + token_pred_words = self.ids_to_clean_tokens(token_pred_ids) token_topk_preds.append(list(zip(token_pred_words, token_pred_probs))) preds["pred_tokens"] = token_topk_preds @@ -326,46 +329,38 @@ def output_spec(self): return spec -class GPT2GenerativeModel(lit_model.BatchedModel): - """Wrapper for a Huggingface Transformers GPT-2 model. - - This class loads a tokenizer and model using the Huggingface library and - provides the LIT-required functions to generate text responses given input - prompts. +class GPT2BaseModel(lit_model.BatchedModel): + """Base class for GPT2 model wrappers.""" - Note that the default model generation config is used such that the response - is produced using multinomial sampling. - """ + @property + def num_layers(self): + return self.model.config.n_layer @classmethod def init_spec(cls) -> lit_model.Spec: return { "model_name_or_path": lit_types.String(default="gpt2"), - "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500), - "batch_size": lit_types.Integer(default=6, min_val=1, max_val=25), + "batch_size": lit_types.Integer(default=6, min_val=1, max_val=64), } def __init__( self, - model=None, - tokenizer=None, model_name_or_path="gpt2", - max_new_tokens=50, batch_size=6, + model=None, + tokenizer=None, ): - """Constructor for GPT2LanguageModel. + """Constructor for GPT2 model wrappers. Note: args "model" and "tokenizer" take priority if both are specified. Otherwise, "model_name_or_path" is used to initialize the model and tokenizer. Args: - model: an initialized GPT2 model compatible with Tensorflow. - tokenizer: an initialized GPT2 tokenizer. - model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, - etc. - max_new_tokens: the maximum number of new tokens to generate. + model_name_or_path: gpt2, gpt2-medium, gpt2-large, distilgpt2, etc. batch_size: the number of items to process per `predict_minibatch` call. + model: an initialized transformers.TFGPT2LMHeadModel. + tokenizer: an initialized GPT2 tokenizer. """ super().__init__() @@ -380,28 +375,103 @@ def __init__( model_name_or_path, extract_compressed_file=True ) + # Note: we need to left-pad for generation to work properly. + # Other modes such as scoring and salience should handle this as well; + # see example in GPT2SalienceModel._postprocess(). self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, use_fast=False + model_name_or_path, + use_fast=False, + padding_side="left", ) # Set this after init, as if pad_token= is passed to # AutoTokenizer.from_pretrained() above it will create a new token with # with id = max_vocab_length and cause out-of-bounds errors in # the embedding lookup. 
- self.tokenizer.pad_token = self.tokenizer.eos_token - self.model = transformers.TFAutoModelForCausalLM.from_pretrained( - model_name_or_path + self.model = transformers.TFGPT2LMHeadModel.from_pretrained( + model_name_or_path, output_hidden_states=True, output_attentions=False ) - self.max_new_tokens = max_new_tokens + self.tokenizer.pad_token = self.tokenizer.eos_token self.batch_size = batch_size - ## - # LIT API implementations + @property + def pad_left(self): + return self.tokenizer.padding_side == "left" + + @classmethod + def from_loaded(cls, existing: "GPT2BaseModel", *args, **kw): + """Share weights and underlying Keras model with another instance.""" + return cls(model=existing.model, tokenizer=existing.tokenizer, *args, **kw) + + def clean_bpe_token(self, tok): + tok = tok.replace("Ċ", "\n") # newlines + tok = tok.replace("Ġ", "▁") # start of word -> magic underscore + return tok + + def ids_to_clean_tokens(self, ids: Sequence[int]) -> list[str]: + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return [self.clean_bpe_token(t) for t in tokens] + def max_minibatch_size(self) -> int: # The BatchedModel base class handles batching automatically in the # implementation of predict(), and uses this value as the batch size. return self.batch_size + def input_spec(self): + return { + "prompt": lit_types.TextSegment(), + "target": lit_types.TextSegment(required=False), + } + + +class GPT2GenerativeModel(GPT2BaseModel): + """Wrapper for a Huggingface Transformers GPT-2 model. + + This class loads a tokenizer and model using the Huggingface library and + provides the LIT-required functions to generate text responses given input + prompts. + + Note that the default model generation config is used such that the response + is produced using multinomial sampling. + """ + + @classmethod + def init_spec(cls) -> lit_model.Spec: + return super().init_spec() | { + "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500) + } + + def __init__(self, *args, max_new_tokens=50, **kw): + """Constructor for GPT2LanguageModel. + + Args: + *args: as to GPT2BaseModel.__init__ + max_new_tokens: the maximum number of new tokens to generate. + **kw: as to GPT2BaseModel.__init__ + """ + super().__init__(*args, **kw) + self.max_new_tokens = max_new_tokens + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # TODO(b/324957491): return actual decoder scores for each generation. + # GeneratedTextCandidates should be a list[(text, score)] + preds["response"] = [(preds["response"], 1.0)] + ntok_in = preds.pop("ntok_in") + embs = preds.pop("embs") + # Mean-pool over input tokens. + preds["prompt_embeddings"] = np.mean( + embs[-(self.max_new_tokens + ntok_in) : -self.max_new_tokens], axis=0 + ) + # Mean-pool over output (generated) tokens. + # TODO(b/324957491): slice this to only "real" output tokens, + # if generation length < max generation length. 
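As a worked illustration of the slicing used here (numbers are arbitrary; this is not code from the patch): with left padding, each row of `embs` is laid out as `[pads][prompt tokens][generated tokens]`, so negative indices recover both spans.

```python
import numpy as np

ntok_in, max_new_tokens, emb_dim = 3, 5, 4
# One left-padded row: [4 pad tokens][3 prompt tokens][5 generated tokens].
embs = np.arange(12 * emb_dim, dtype=np.float32).reshape(12, emb_dim)
prompt_embs = embs[-(max_new_tokens + ntok_in):-max_new_tokens]  # rows 4..6
response_embs = embs[-max_new_tokens:]                           # rows 7..11
prompt_embeddings = prompt_embs.mean(axis=0)                     # [emb_dim]
response_embeddings = response_embs.mean(axis=0)                 # [emb_dim]
```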
+ preds["response_embeddings"] = np.mean(embs[-self.max_new_tokens :], axis=0) + + return preds + + ## + # LIT API implementations def predict_minibatch(self, inputs): prompts = [ex["prompt"] for ex in inputs] encoded_inputs = self.tokenizer.batch_encode_plus( @@ -413,28 +483,210 @@ def predict_minibatch(self, inputs): ) outputs = self.model.generate( encoded_inputs["input_ids"], + attention_mask=encoded_inputs["attention_mask"], max_new_tokens=self.max_new_tokens, ) + responses = self.tokenizer.batch_decode( outputs[:, -self.max_new_tokens :], skip_special_tokens=True ) + # Input embeddings: [batch_size, num_tokens, emb_dim] embeddings = self.model.transformer.wte(outputs) - return [ - { - "response": responses[i], - "prompt_embeddings": embeddings[i, : -self.max_new_tokens], - "response_embeddings": embeddings[i, -self.max_new_tokens :] - } for i in range(len(outputs)) - ] + batched_outputs = { + "embs": embeddings, + "ntok_in": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1), + # TODO(b/324957491): compute ntok_out if < max_output_tokens ? + } + + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + detached_outputs["response"] = responses + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) + + def output_spec(self) -> lit_types.Spec: + return { + "response": lit_types.GeneratedTextCandidates(parent="target"), + "prompt_embeddings": lit_types.Embeddings(required=False), + "response_embeddings": lit_types.Embeddings(required=False), + } + + +class GPT2SalienceModel(GPT2BaseModel): + """Wrapper for GPT-2 input (token) salience.""" + + def _pred(self, encoded_inputs, target_masks): + """Predicts one batch of tokenized text. + + Also performs some batch-level post-processing in TF. + Single-example postprocessing is done in _postprocess(), and operates on + numpy arrays. + + Args: + encoded_inputs: output of self.tokenizer.batch_encode_plus() + target_masks: list(array_like) of binary (0/1) masks for each input + + Returns: + payload: Dictionary with items described above, each as single Tensor. + """ + input_ids = encoded_inputs["input_ids"] + + # [batch_size, num_tokens]; ignore the last one in each row. + target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1) + ## + # Process target masks + + # It doesn't make sense to interpret the first token, since it is not ever + # predicted. But we need to ensure that the mask[0] is zero, so it doesn't + # cause problems when 'rolled' to the last position below. + modified_masks = [[0] + list(mask[1:]) for mask in target_masks] + seq_len = target_ids.shape[1] + pad_fn = functools.partial( + utils.pad1d, + min_len=seq_len, + max_len=seq_len, + pad_val=0, + pad_left=self.pad_left, + ) + padded_target_masks = np.stack( + [pad_fn(mask) for mask in modified_masks], + axis=0, + ) + + padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32) + # Shift masks back so they align with target_ids. + loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) + + with tf.GradientTape(watch_accessed_variables=True) as tape: + # We need to run the embedding layer ourselves so we can trace it. 
+ # See here for how the model normally does this: + # http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271 + embs = self.model.transformer.wte(input_ids, mode="embedding") + tape.watch(embs) + + out = self.model( + input_ids=None, + inputs_embeds=embs, + attention_mask=encoded_inputs["attention_mask"], + ) + + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + # [batch_size, num_tokens] + per_token_loss = loss_fn(target_ids, out.logits) + masked_loss = per_token_loss * loss_mask + + grads = tape.gradient( + masked_loss, embs + ) # [batch_size, num_tokens, hdim] + + grad_l2 = tf.norm(grads, axis=2) # [batch_size, num_tokens] + grad_dot_input = tf.reduce_sum( + grads * embs, axis=2 + ) # [batch_size, num_tokens] + + batched_outputs = { + "input_ids": encoded_inputs["input_ids"], + "attention_mask": encoded_inputs["attention_mask"], + # Gradients are already aligned to input tokens. + "grad_l2": grad_l2, + "grad_dot_input": grad_dot_input, + # Shift token loss to align with (input) tokens. + "token_loss": tf.roll(per_token_loss, shift=1, axis=1), + } + + return batched_outputs + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # Be sure to cast to bool, otherwise this will select intger positions 0, 1 + # rather than acting as a boolean mask. + mask = preds.pop("attention_mask").astype(bool) + ids = preds.pop("input_ids")[mask] + preds["tokens"] = self.ids_to_clean_tokens(ids) + for key in utils.find_spec_keys(self.output_spec(), lit_types.TokenScores): + preds[key] = preds[key][mask] + # First token (usually ) is not actually predicted, so return 0 for loss. + preds["token_loss"][0] = 0 + + return preds + + # LIT API implementations + def predict_minibatch(self, inputs): + """Predict on a single minibatch of examples.""" + # Preprocess inputs. + texts = [ex["prompt"] + ex.get("target", "") for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + texts, + return_tensors="tf", + add_special_tokens=True, + padding="longest", + truncation="longest_first", + ) + target_masks = [ex.get("target_mask", []) for ex in inputs] + + # Get the predictions. + batched_outputs = self._pred(encoded_inputs, target_masks) + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) def input_spec(self): + return super().input_spec() | { + "target_mask": lit_types.TokenScores(align="", required=False), + } + + def output_spec(self) -> lit_types.Spec: return { - "prompt": lit_types.TextSegment(), + "tokens": lit_types.Tokens(parent=""), # all tokens + "grad_l2": lit_types.TokenScores(align="tokens"), + "grad_dot_input": lit_types.TokenScores(align="tokens"), + "token_loss": lit_types.TokenScores(align="tokens"), + } + + +class GPT2TokenizerModel(GPT2BaseModel): + """Wrapper to run only the tokenizer. + + Should exactly match tokens from GPT2SalienceModel. + """ + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # Be sure to cast to bool, otherwise this will select intger positions 0, 1 + # rather than acting as a boolean mask. 
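A tiny standalone illustration of the pitfall this comment warns about (not part of the patch):

```python
import numpy as np

ids = np.array([50256, 15496, 11, 995])
mask = np.array([1, 1, 1, 0])      # attention mask stored as integers
print(ids[mask])                   # integer indexing: picks positions 1, 1, 1, 0
print(ids[mask.astype(bool)])      # boolean masking: keeps the first three ids
```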
+ mask = preds.pop("attention_mask").astype(bool) + ids = preds.pop("input_ids")[mask] + preds["tokens"] = self.ids_to_clean_tokens(ids) + return preds + + # LIT API implementations + def predict_minibatch(self, inputs): + """Predict on a single minibatch of examples.""" + # Preprocess inputs. + texts = [ex["prompt"] + ex.get("target", "") for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + texts, + return_tensors="tf", + add_special_tokens=True, + padding="longest", + truncation="longest_first", + ) + batched_outputs = { + "input_ids": encoded_inputs["input_ids"], + "attention_mask": encoded_inputs["attention_mask"], } + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) def output_spec(self) -> lit_types.Spec: return { - "response": lit_types.GeneratedTextCandidates(), - "prompt_embeddings": lit_types.Embeddings(required=False), - "response_embeddings": lit_types.Embeddings(required=False) + "tokens": lit_types.Tokens(parent=""), # all tokens } diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py index 62583ef6..84dce7e7 100644 --- a/lit_nlp/examples/models/pretrained_lms_int_test.py +++ b/lit_nlp/examples/models/pretrained_lms_int_test.py @@ -44,8 +44,10 @@ def test_gpt2_generation(self): self.assertIn(key, model_out[0].keys()) # Check that the embedding dimension is the same for prompt and response. - self.assertEqual(model_out[0]["prompt_embeddings"].shape[1], - model_out[0]["response_embeddings"].shape[1]) + self.assertEqual( + model_out[0]["prompt_embeddings"].shape, + model_out[0]["response_embeddings"].shape, + ) if __name__ == "__main__": From 1df3ba8449e865edb5806c10c8054c246d1e38e3 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Tue, 13 Feb 2024 18:17:00 -0800 Subject: [PATCH 23/51] Generic instrumented Keras LM wrapper for LM salience PiperOrigin-RevId: 606810304 --- .../examples/models/instrumented_keras_lms.py | 375 ++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 lit_nlp/examples/models/instrumented_keras_lms.py diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py new file mode 100644 index 00000000..fc8f68e3 --- /dev/null +++ b/lit_nlp/examples/models/instrumented_keras_lms.py @@ -0,0 +1,375 @@ +"""LIT model wrappers for generic instrumented Keras LMs.""" + +import functools +import inspect +import types +from typing import Sequence + +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.lib import utils as lit_utils +import numpy as np +import tensorflow as tf + + +_DEFAULT_MAX_LENGTH = 1024 + + +class FieldNames(types.SimpleNamespace): + PROMPT = "prompt" + RESPONSE = "response" + PROMPT_EMBEDDINGS = "prompt_embeddings" + RESPONSE_EMBEDDINGS = "response_embeddings" + TARGET = "target" + TOKENS = "tokens" + TARGET_MASK = "target_mask" + GRAD_DOT_INPUT = "grad_dot_input" + GRAD_NORM = "grad_l2" + TOKEN_LOSS = "token_loss" + + +class _KerasBaseModel(lit_model.BatchedModel): + """Base LIT model wrapper class for Keras on TensorFlow.""" + + # TODO(lit-dev): pytype annotations for model= ? 
+ # Should be keras_nlp.models.generative_task.GenerativeTask + def __init__( + self, + model, + max_length: int = _DEFAULT_MAX_LENGTH, + batch_size: int = 16, + ): + """Base wrapper for a Keras/TF2 LM supporting the layer_intercept_fn API. + + Model should support the following methods: + - .generate() + - .score()* + - .preprocessor.generate_preprocess() + . .preprocessor.tokenizer.id_to_token() + . .backbone.token_embedding() + + * The score function should accept layer_intercept_fn= as a way to intercept + and manipulate activations between layers. We use this for salience, below. + + Args: + model: pre-loaded Keras LM using the TF backend + max_length: max sequence length + batch_size: batch size + """ + super().__init__() + + self.model = model + self.batch_size = batch_size + self.max_length = max_length + + self.encode_inputs = self.model.preprocessor.generate_preprocess + + self.ids_to_tokens = np.vectorize( + self.model.preprocessor.tokenizer.id_to_token + ) + + # map ids: [batch_size, num_tokens] + # to embs: [batch_size, num_tokens, emb_dim] + self.embedder = self.model.backbone.token_embedding + + @classmethod + def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw): + """Share weights and underlying Keras model with another instance.""" + return cls(model=existing.model, *args, **kw) + + def max_minibatch_size(self) -> int: + return self.batch_size + + @classmethod + def init_spec(cls): + # Cannot initialize from spec, because we need a Keras model object. + return None + + def input_spec(self): + return { + FieldNames.PROMPT: lit_types.TextSegment(), + FieldNames.TARGET: lit_types.TextSegment(required=False), + } + + +class KerasGenerationModel(_KerasBaseModel): + """LIT model wrapper for generating text with Keras on TensorFlow. + + This class accepts a loaded model and provides the LIT-required functions plus + additional helper functions for generation tasks. + + This class supports generation and pass-through modes. If a dataset provides a + pre-populated 'response' column then this model will return that text instead + of generating new text from the 'prompt'. This allows the same model wrapper + to be efficiently used to examine saved results from bulk-inference pipelines + and new generations from, e.g., counterfactually generated examples, or novel + evaluation datasets. + """ + + def __init__(self, *args, output_embeddings=True, **kw): + super().__init__(*args, **kw) + self.output_embeddings = output_embeddings + + def embed_texts(self, texts: Sequence[str]): + processed_inputs = self.encode_inputs( + texts, sequence_length=self.max_length + ) + # [batch_size, num_tokens, emb_dim] + embs = self.embedder(processed_inputs["token_ids"]) + # [batch_size, num_tokens] + mask = processed_inputs["padding_mask"] + return embs, mask + + def embed_and_mean_pool(self, texts: Sequence[str]): + """Return a single vector for each text.""" + embs, mask = self.embed_texts(texts) + # [batch_size, num_tokens, 1] + mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2) + # [batch_size, 1, emb_dim] + pooled_embs = tf.reduce_sum( + mask * embs, axis=1, keepdims=True + ) / tf.reduce_sum(mask, axis=1, keepdims=True) + # [batch_size, emb_dim] + return tf.squeeze(pooled_embs, axis=1) + + def predict_minibatch( + self, + inputs: list[lit_types.JsonDict], + ) -> list[lit_types.JsonDict]: + prompts: Sequence[str] = [ex[FieldNames.PROMPT] for ex in inputs] + + # TODO(lit-dev): suppport loading cached responses here, since running + # generation can be expensive. 
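One possible shape for the pass-through behavior described in the class docstring, sketched here as a standalone helper; the early-return policy is an assumption, not code from this patch:

```python
from typing import Optional

def maybe_passthrough(inputs: list[dict]) -> Optional[list[dict]]:
  """If every example already carries a 'response', skip generation entirely."""
  cached = [ex.get("response") for ex in inputs]
  if all(cached):
    return [{"response": response} for response in cached]
  return None  # Fall through to model.generate().
```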
+ full_responses: Sequence[str] = list( + self.model.generate(prompts, max_length=self.max_length) + ) + # Model outputs include the prompt, so trim that off and just return the + # generated portion. + responses: Sequence[str] = [ + response[len(prompt) :] + for response, prompt in zip(full_responses, prompts) + ] + + outputs = [{FieldNames.RESPONSE: response} for response in responses] + + if self.output_embeddings: + prompt_embeddings = self.embed_and_mean_pool(prompts) + # TODO(lit-dev): embed prompt + response and trim embedding instead? + # Or just embed full_response. + response_embeddings = self.embed_and_mean_pool(responses) + + for i in range(len(inputs)): + outputs[i][FieldNames.PROMPT_EMBEDDINGS] = prompt_embeddings[i].numpy() + outputs[i][FieldNames.RESPONSE_EMBEDDINGS] = response_embeddings[ + i + ].numpy() + + return outputs + + def output_spec(self) -> lit_types.Spec: + ret = { + FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET) + } + if self.output_embeddings: + return ret | { + FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(), + FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(), + } + return ret + + +class KerasSalienceModel(_KerasBaseModel): + """LIT model wrapper for computing salience with Keras on TensorFlow. + + This class accepts a loaded model and provides the LIT-required functions plus + additional helper functions to convert and clean tokens and to compute + sequence salience. + + This class does not support generation; use the KerasGenerationModel class to + generate the text for which this class will compute salience. + """ + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + + score_fn = getattr(self.model, "score", None) + + if score_fn is None or not inspect.ismethod(score_fn): + raise TypeError( + "Salience is computed via a .score() API, which is not supported by " + "all GenerativeTask models in KerasNLP. Please provide a model that " + "supports this API." + ) + + def _pred(self, input_ids, padding_mask, target_masks): + """Predict a batch of tokenized text.""" + # [batch_size, num_tokens]; ignore the last one in each row. + target_ids = tf.roll(input_ids, shift=-1, axis=1) + + ## + # Process target masks + + # It doesn't make sense to interpret the first token, since it is not ever + # predicted. But we need to ensure that the mask[0] is zero, so it doesn't + # cause problems when 'rolled' to the last position below. + modified_masks = [[0] + list(mask[1:]) for mask in target_masks] + seq_len = target_ids.shape[1] + pad_fn = functools.partial( + lit_utils.pad1d, + min_len=seq_len, + max_len=seq_len, + pad_val=0, + pad_left=False, + ) + padded_target_masks = np.stack( + [pad_fn(mask) for mask in modified_masks], + axis=0, + ) + + padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32) + # Shift masks back so they align with target_ids. 
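To make the alignment concrete (a standalone sketch, not code from the patch): a mask over input positions must be rolled by the same shift as `target_ids`, so that position i's loss is kept exactly when token i + 1 is in the selected span.

```python
import tensorflow as tf

input_ids = tf.constant([[10, 11, 12, 13]])
target_ids = tf.roll(input_ids, shift=-1, axis=1)     # [[11, 12, 13, 10]]
# Explain the tokens at input positions 2 and 3:
target_mask = tf.constant([[0., 0., 1., 1.]])
loss_mask = tf.roll(target_mask, shift=-1, axis=1)    # [[0., 1., 1., 0.]]
# Loss is kept at positions 1 and 2, whose predictions produce tokens 2 and 3.
```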
+ loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) + + embeddings = None + + with tf.GradientTape(watch_accessed_variables=True) as tape: + + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + # [batch_size, num_tokens] + per_token_loss = self.model.score( + token_ids=input_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + masked_loss = per_token_loss * loss_mask + + # [batch_size, num_tokens, hdim] + grads = tape.gradient(masked_loss, embeddings) + # [batch_size, num_tokens] + grad_l2 = tf.norm(grads, axis=2) + # [batch_size, num_tokens] + grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2) + + batched_outputs = { + "input_ids": input_ids, + "padding_mask": padding_mask, + # Gradients are already aligned to input tokens. + FieldNames.GRAD_NORM: grad_l2, + FieldNames.GRAD_DOT_INPUT: grad_dot_input, + # Shift token loss to align with (input) tokens. + FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1), + } + + return batched_outputs + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + mask = preds.pop("padding_mask").astype(bool) + ids = preds.pop("input_ids")[mask] + preds[FieldNames.TOKENS] = self.ids_to_tokens(ids) + for key in lit_utils.find_spec_keys( + self.output_spec(), lit_types.TokenScores + ): + preds[key] = preds[key][mask] + # First token () is not actually predicted, so return 0 for loss. + preds[FieldNames.TOKEN_LOSS][0] = 0 + + return preds + + def predict_minibatch(self, inputs): + """Predict on a single minibatch of examples.""" + texts: Sequence[str] = [ + ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs + ] + preprocessed_texts = self.encode_inputs( + texts, sequence_length=self.max_length + ) + sequence_ids = preprocessed_texts["token_ids"] + padding_mask = preprocessed_texts["padding_mask"] + + target_masks = [ex.get(FieldNames.TARGET_MASK, []) for ex in inputs] + + # Get the predictions. + batched_outputs = self._pred(sequence_ids, padding_mask, target_masks) + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = lit_utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) + + def input_spec(self): + return super().input_spec() | { + FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False), + } + + def output_spec(self) -> lit_types.Spec: + return { + FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens. + FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores( + align=FieldNames.TOKENS + ), + FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS), + FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS), + } + + +class KerasTokenizerModel(_KerasBaseModel): + """LIT model wrapper for tokenizing text with Keras on TensorFlow. + + This class accepts a loaded model and provides the LIT-required functions plus + additional helper functions to convert and clean tokens. + """ + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + # Be sure to cast to bool, otherwise this will select intger positions 0, 1 + # rather than acting as a boolean mask. 
+ mask = preds.pop("padding_mask").astype(bool) + ids = preds.pop("token_ids")[mask] + preds[FieldNames.TOKENS] = self.ids_to_tokens(ids) + return preds + + def predict_minibatch(self, inputs): + """Tokenize a single minibatch of examples.""" + texts: Sequence[str] = [ + ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs + ] + preprocessed_texts = self.encode_inputs( + texts, sequence_length=self.max_length + ) + batched_outputs = { + "token_ids": preprocessed_texts["token_ids"], + "padding_mask": preprocessed_texts["padding_mask"], + } + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = lit_utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) + + def output_spec(self) -> lit_types.Spec: + return { + FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens. + } + + +def initialize_model_group_for_salience( + name, *args, **kw +) -> dict[str, lit_model.Model]: + """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'.""" + generation_model = KerasGenerationModel(*args, **kw) + salience_model = KerasSalienceModel(*args, **kw) + tokenizer_model = KerasTokenizerModel(*args, **kw) + return { + name: generation_model, + f"_{name}_salience": salience_model, + f"_{name}_tokenizer": tokenizer_model, + } From b6ab3522b301810cab3c75723f3fe0dabf829577 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 14 Feb 2024 13:41:00 -0800 Subject: [PATCH 24/51] Support different float precision for LM salience. Also set watch_accessed_variables=False, because we don't need it. PiperOrigin-RevId: 607091706 --- lit_nlp/examples/lm_salience_demo.py | 10 +++++++ .../examples/models/instrumented_keras_lms.py | 30 +++++++++---------- lit_nlp/examples/models/pretrained_lms.py | 18 +++++------ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py index f9637940..3f180e09 100644 --- a/lit_nlp/examples/lm_salience_demo.py +++ b/lit_nlp/examples/lm_salience_demo.py @@ -2,12 +2,14 @@ from collections.abc import Sequence import functools +import os import sys from typing import Optional from absl import app from absl import flags from absl import logging +import keras from lit_nlp import dev_server from lit_nlp import server_flags from lit_nlp.api import layout @@ -37,6 +39,10 @@ ), ) +_KERAS_FLOATX = flags.DEFINE_string( + "keras_floatx", "bfloat16", "Floating-point type for Keras models." +) + # Custom frontend layout; see api/layout.py modules = layout.LitModuleName LM_LAYOUT = layout.LitCanonicalLayout( @@ -109,6 +115,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") + # Set Keras backend and floating-point precision. 
+ os.environ["KERAS_BACKEND"] = "tensorflow" + keras.config.set_floatx(_KERAS_FLOATX.value) + plaintextPrompts = functools.partial( # pylint: disable=invalid-name lm_data.PlaintextSents, field_name="prompt" ) diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py index fc8f68e3..19dd9a4b 100644 --- a/lit_nlp/examples/models/instrumented_keras_lms.py +++ b/lit_nlp/examples/models/instrumented_keras_lms.py @@ -68,8 +68,8 @@ def __init__( self.model.preprocessor.tokenizer.id_to_token ) - # map ids: [batch_size, num_tokens] - # to embs: [batch_size, num_tokens, emb_dim] + # map ids: [batch_size, num_tokens] + # to embs: [batch_size, num_tokens, emb_dim] self.embedder = self.model.backbone.token_embedding @classmethod @@ -114,7 +114,7 @@ def embed_texts(self, texts: Sequence[str]): processed_inputs = self.encode_inputs( texts, sequence_length=self.max_length ) - # [batch_size, num_tokens, emb_dim] + # [batch_size, num_tokens, emb_dim] embs = self.embedder(processed_inputs["token_ids"]) # [batch_size, num_tokens] mask = processed_inputs["padding_mask"] @@ -123,13 +123,13 @@ def embed_texts(self, texts: Sequence[str]): def embed_and_mean_pool(self, texts: Sequence[str]): """Return a single vector for each text.""" embs, mask = self.embed_texts(texts) - # [batch_size, num_tokens, 1] - mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2) - # [batch_size, 1, emb_dim] + # [batch_size, num_tokens, 1] + mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2) + # [batch_size, 1, emb_dim] pooled_embs = tf.reduce_sum( mask * embs, axis=1, keepdims=True ) / tf.reduce_sum(mask, axis=1, keepdims=True) - # [batch_size, emb_dim] + # [batch_size, emb_dim] return tf.squeeze(pooled_embs, axis=1) def predict_minibatch( @@ -203,7 +203,7 @@ def __init__(self, *args, **kw): def _pred(self, input_ids, padding_mask, target_masks): """Predict a batch of tokenized text.""" - # [batch_size, num_tokens]; ignore the last one in each row. + # [batch_size, num_tokens]; ignore the last one in each row. target_ids = tf.roll(input_ids, shift=-1, axis=1) ## @@ -226,13 +226,13 @@ def _pred(self, input_ids, padding_mask, target_masks): axis=0, ) - padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32) + padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool) # Shift masks back so they align with target_ids. 
loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) embeddings = None - with tf.GradientTape(watch_accessed_variables=True) as tape: + with tf.GradientTape(watch_accessed_variables=False) as tape: def layer_intercept_fn(x, i): if i == -1: @@ -241,7 +241,7 @@ def layer_intercept_fn(x, i): tape.watch(embeddings) return x - # [batch_size, num_tokens] + # [batch_size, num_tokens] per_token_loss = self.model.score( token_ids=input_ids, padding_mask=padding_mask, @@ -249,13 +249,13 @@ def layer_intercept_fn(x, i): layer_intercept_fn=layer_intercept_fn, target_ids=target_ids, ) - masked_loss = per_token_loss * loss_mask + masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype) - # [batch_size, num_tokens, hdim] + # [batch_size, num_tokens, hdim] grads = tape.gradient(masked_loss, embeddings) - # [batch_size, num_tokens] + # [batch_size, num_tokens] grad_l2 = tf.norm(grads, axis=2) - # [batch_size, num_tokens] + # [batch_size, num_tokens] grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2) batched_outputs = { diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 28baea46..9fdae45d 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -490,7 +490,7 @@ def predict_minibatch(self, inputs): responses = self.tokenizer.batch_decode( outputs[:, -self.max_new_tokens :], skip_special_tokens=True ) - # Input embeddings: [batch_size, num_tokens, emb_dim] + # Input embeddings: [batch_size, num_tokens, emb_dim] embeddings = self.model.transformer.wte(outputs) batched_outputs = { "embs": embeddings, @@ -532,7 +532,7 @@ def _pred(self, encoded_inputs, target_masks): """ input_ids = encoded_inputs["input_ids"] - # [batch_size, num_tokens]; ignore the last one in each row. + # [batch_size, num_tokens]; ignore the last one in each row. target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1) ## # Process target masks @@ -554,11 +554,11 @@ def _pred(self, encoded_inputs, target_masks): axis=0, ) - padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32) + padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool) # Shift masks back so they align with target_ids. loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) - with tf.GradientTape(watch_accessed_variables=True) as tape: + with tf.GradientTape(watch_accessed_variables=False) as tape: # We need to run the embedding layer ourselves so we can trace it. 
# See here for how the model normally does this: # http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271 @@ -574,18 +574,18 @@ def _pred(self, encoded_inputs, target_masks): loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction="none" ) - # [batch_size, num_tokens] + # [batch_size, num_tokens] per_token_loss = loss_fn(target_ids, out.logits) - masked_loss = per_token_loss * loss_mask + masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype) grads = tape.gradient( masked_loss, embs - ) # [batch_size, num_tokens, hdim] + ) # [batch_size, num_tokens, hdim] - grad_l2 = tf.norm(grads, axis=2) # [batch_size, num_tokens] + grad_l2 = tf.norm(grads, axis=2) # [batch_size, num_tokens] grad_dot_input = tf.reduce_sum( grads * embs, axis=2 - ) # [batch_size, num_tokens] + ) # [batch_size, num_tokens] batched_outputs = { "input_ids": encoded_inputs["input_ids"], From c97a710416538906ea6b269f90264c0602a15593 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 14 Feb 2024 22:24:19 -0800 Subject: [PATCH 25/51] Dynamic sequence length for Keras LM wrappers. PiperOrigin-RevId: 607214308 --- .../examples/models/instrumented_keras_lms.py | 60 +++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py index 19dd9a4b..96a0af72 100644 --- a/lit_nlp/examples/models/instrumented_keras_lms.py +++ b/lit_nlp/examples/models/instrumented_keras_lms.py @@ -5,6 +5,7 @@ import types from typing import Sequence +from absl import logging from lit_nlp.api import model as lit_model from lit_nlp.api import types as lit_types from lit_nlp.lib import utils as lit_utils @@ -37,6 +38,7 @@ def __init__( self, model, max_length: int = _DEFAULT_MAX_LENGTH, + dynamic_sequence_length: bool = True, batch_size: int = 16, ): """Base wrapper for a Keras/TF2 LM supporting the layer_intercept_fn API. @@ -54,6 +56,9 @@ def __init__( Args: model: pre-loaded Keras LM using the TF backend max_length: max sequence length + dynamic_sequence_length: if true, will trim padding to the length of the + longest sequence in a batch. Recommended for CPU and GPU usage, but may + be disabled for compilation where a fixed shape is required. batch_size: batch size """ super().__init__() @@ -61,8 +66,7 @@ def __init__( self.model = model self.batch_size = batch_size self.max_length = max_length - - self.encode_inputs = self.model.preprocessor.generate_preprocess + self.dynamic_sequence_length = dynamic_sequence_length self.ids_to_tokens = np.vectorize( self.model.preprocessor.tokenizer.id_to_token @@ -72,6 +76,46 @@ def __init__( # to embs: [batch_size, num_tokens, emb_dim] self.embedder = self.model.backbone.token_embedding + def encode_inputs(self, texts: Sequence[str]): + """Encode inputs, with optional dynamic trimming. + + By default, the model's generate_preprocess() pads to a fixed sequence + length, either specified as sequence_length= or using an internal default. + + Here, we optionally trim this to remove extraneous padding positions based + on the actual contents of the minibatch. This can greatly speed up + performance when running on CPU or GPU. 
+ + Args: + texts: list of input strings + + Returns: + encoded_inputs compatible with model.score() or other functions + """ + # First: pack to max_length + encoded_inputs = self.model.preprocessor.generate_preprocess( + texts, sequence_length=self.max_length + ) + if not self.dynamic_sequence_length: + return encoded_inputs + + # Trim to the maximum length needed to contain any non-padding tokens. + mask = encoded_inputs["padding_mask"] + # Find position of last 'True' in each row. + seq_ends: Sequence[int] = [ + 1 + tf.reduce_max(tf.where(mask[i])).numpy().tolist() + for i in range(mask.shape[0]) + ] + trimmed_length = max(seq_ends) + # TODO(lit-dev): remove this line, or make it logging.debug ? + logging.info( + "Trimming batch to trimmed_length = %d based on sequence ends %s", + trimmed_length, + seq_ends, + ) + # Actually trim the input tensors. + return {k: v[:, :trimmed_length] for k, v in encoded_inputs.items()} + @classmethod def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw): """Share weights and underlying Keras model with another instance.""" @@ -111,9 +155,7 @@ def __init__(self, *args, output_embeddings=True, **kw): self.output_embeddings = output_embeddings def embed_texts(self, texts: Sequence[str]): - processed_inputs = self.encode_inputs( - texts, sequence_length=self.max_length - ) + processed_inputs = self.encode_inputs(texts) # [batch_size, num_tokens, emb_dim] embs = self.embedder(processed_inputs["token_ids"]) # [batch_size, num_tokens] @@ -289,9 +331,7 @@ def predict_minibatch(self, inputs): texts: Sequence[str] = [ ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs ] - preprocessed_texts = self.encode_inputs( - texts, sequence_length=self.max_length - ) + preprocessed_texts = self.encode_inputs(texts) sequence_ids = preprocessed_texts["token_ids"] padding_mask = preprocessed_texts["padding_mask"] @@ -342,9 +382,7 @@ def predict_minibatch(self, inputs): texts: Sequence[str] = [ ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs ] - preprocessed_texts = self.encode_inputs( - texts, sequence_length=self.max_length - ) + preprocessed_texts = self.encode_inputs(texts) batched_outputs = { "token_ids": preprocessed_texts["token_ids"], "padding_mask": preprocessed_texts["padding_mask"], From 40bb57a2531257c38137188090a24e70d47581c8 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 14 Feb 2024 23:04:55 -0800 Subject: [PATCH 26/51] UI bugfixes for LM salience: - Buttons poking through maximized elements - Tooltip positioning - Module correctly responds to switching models in the UI - No longer display stale tokens PiperOrigin-RevId: 607221911 --- lit_nlp/client/core/lit_module.ts | 2 +- lit_nlp/client/elements/fused_button_bar.css | 2 +- lit_nlp/client/elements/tooltip.css | 12 ++--- lit_nlp/client/modules/lm_salience_module.ts | 57 ++++++++++---------- lit_nlp/examples/lm_salience_demo.py | 3 +- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/lit_nlp/client/core/lit_module.ts b/lit_nlp/client/core/lit_module.ts index 86a9df35..b9b05e02 100644 --- a/lit_nlp/client/core/lit_module.ts +++ b/lit_nlp/client/core/lit_module.ts @@ -84,7 +84,7 @@ export abstract class LitModule extends ReactiveElement { (model: string, selectionServiceIndex: number, shouldReact: number) => TemplateResult = () => html``; - @property({type: String}) model = ''; + @observable @property({type: String}) model = ''; @observable @property({type: Number}) selectionServiceIndex = 0; // tslint:disable-next-line:no-any diff --git 
a/lit_nlp/client/elements/fused_button_bar.css b/lit_nlp/client/elements/fused_button_bar.css index da39065c..4fa980a2 100644 --- a/lit_nlp/client/elements/fused_button_bar.css +++ b/lit_nlp/client/elements/fused_button_bar.css @@ -33,5 +33,5 @@ } .button-bar-item button.active { - z-index: 2; /* show active border above neighbors */ + z-index: 1; /* show active border above neighbors */ } \ No newline at end of file diff --git a/lit_nlp/client/elements/tooltip.css b/lit_nlp/client/elements/tooltip.css index 86ae7243..195ea732 100644 --- a/lit_nlp/client/elements/tooltip.css +++ b/lit_nlp/client/elements/tooltip.css @@ -10,10 +10,10 @@ * with tooltip positioning. */ --anchor-display-mode: inline-block; - --tooltip-position-left: unset; - --tooltip-position-right: unset; - --tooltip-position-top: unset; - --tooltip-position-bottom: unset; + --tooltip-position-left: auto; + --tooltip-position-right: auto; + --tooltip-position-top: auto; + --tooltip-position-bottom: auto; } /* Tooltip */ @@ -49,11 +49,11 @@ overflow: hidden; } -.above { +.tooltip-text.above { bottom: 28px; } -.left { +.tooltip-text.left { right: 12px; } diff --git a/lit_nlp/client/modules/lm_salience_module.ts b/lit_nlp/client/modules/lm_salience_module.ts index b2deb9c7..7bddaac3 100644 --- a/lit_nlp/client/modules/lm_salience_module.ts +++ b/lit_nlp/client/modules/lm_salience_module.ts @@ -10,7 +10,7 @@ import '../elements/fused_button_bar'; import {css, html} from 'lit'; // tslint:disable:no-new-decorators import {customElement} from 'lit/decorators.js'; -import {computed, observable, toJS} from 'mobx'; +import {computed, observable} from 'mobx'; import {LitModule} from '../core/lit_module'; import {LegendType} from '../elements/color_legend'; @@ -89,9 +89,10 @@ export class SingleExampleSingleModelModule extends LitModule { this.currentPreds = undefined; } - protected async updateToSelection(input: IndexedInput|null) { + protected async updateToSelection() { this.resetState(); + const input = this.selectionService.primarySelectedInputData; if (input == null) return; // Before waiting for the backend call, update data. 
@@ -104,7 +105,7 @@ export class SingleExampleSingleModelModule extends LitModule { this.appState.currentDataset, this.predsTypes, [], - 'Getting model predictions.', + `Getting predictions from ${this.model}`, ); const results = await this.loadLatest('modelPreds', promise); if (results === null) return; @@ -118,9 +119,11 @@ export class SingleExampleSingleModelModule extends LitModule { override firstUpdated() { this.reactImmediately( - () => this.selectionService.primarySelectedInputData, - (data) => { - this.updateToSelection(data); + () => + [this.selectionService.primarySelectedInputData, this.model, + this.appState.currentDataset], + () => { + this.updateToSelection(); }, ); } @@ -259,8 +262,8 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { } // Get generations; populate this.currentPreds - protected override async updateToSelection(input: IndexedInput|null) { - await super.updateToSelection(input); + protected override async updateToSelection() { + await super.updateToSelection(); this.resetTargetSpan(); const dataSpec = this.appState.currentDatasetSpec; @@ -415,9 +418,11 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { return `${span[0]}:${span[1]}`; } - async updateTokens(input: IndexedInput|null) { + async updateTokens() { + this.currentTokens = []; + + const input = this.modifiedData; if (input == null) { - this.currentTokens = []; return; } @@ -427,24 +432,17 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { this.appState.currentDataset, [Tokens], [], - `Fetching tokens`, + `Fetching tokens for model ${this.model}`, ); - const results = await promise; + const results = await this.loadLatest('updateTokens', promise); if (results === null) { - console.warn('No tokens returned for request', input); + console.warn('No tokens returned or stale request for example', input); return; } // TODO(b/324959547): get field name from spec, rather than hardcoding // 'tokens'. - const tokens: string[] = results[0]['tokens']; - if (this.modifiedData === input) { - this.currentTokens = tokens; - } else { - console.warn( - 'Stale request; discarding result. Request does not match current target.', - input, toJS(this.modifiedData)); - } + this.currentTokens = results[0]['tokens']; } async updateSalience(targetTokenSpan: number[]|undefined) { @@ -503,11 +501,16 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { // update completed, causing a new update to be scheduled." // This is okay here: this.modifiedData will be updated after // updateToSelection() runs, which will trigger this to update tokens. - this.reactImmediately(() => this.modifiedData, (data) => { - this.resetTargetSpan(); - this.updateTokens(data); - }); - + this.reactImmediately( + () => [this.modifiedData, this.model, this.appState.currentDataset], + () => { + this.resetTargetSpan(); + this.updateTokens(); + }); + + // This can react only to targetTokenSpan, because changes to + // this.model or this.appState.currentDataset will trigger the target span + // to be reset. this.reactImmediately(() => this.targetTokenSpan, (targetTokenSpan) => { this.updateSalience(targetTokenSpan); }); @@ -648,7 +651,7 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { 'segment(s)'; // prettier-ignore return html` - Click ${segmentType} above to select a target span. + Click ${segmentType} above to select a target to explain. 
`; } const [start, end] = this.targetTokenSpan; diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py index 3f180e09..bea15b35 100644 --- a/lit_nlp/examples/lm_salience_demo.py +++ b/lit_nlp/examples/lm_salience_demo.py @@ -25,7 +25,8 @@ _MODELS = flags.DEFINE_list( "models", [ - "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz" + "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz", + "distilgpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/distilgpt2.tar.gz", ], "Models to load, as :. Currently supports GPT-2 variants.", ) From d3980cc5414e1f9be895defc4f967bee8a2480fc Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Thu, 15 Feb 2024 08:57:29 -0800 Subject: [PATCH 27/51] Fix end-of-line token chip clipping in block mode. Also fix another small z-index issue. PiperOrigin-RevId: 607350268 --- lit_nlp/client/elements/token_chips.css | 22 +++++++++++++++++++- lit_nlp/client/elements/token_chips.ts | 19 ++++++++++------- lit_nlp/client/modules/lm_salience_module.ts | 2 +- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/lit_nlp/client/elements/token_chips.css b/lit_nlp/client/elements/token_chips.css index bea68608..4b241431 100644 --- a/lit_nlp/client/elements/token_chips.css +++ b/lit_nlp/client/elements/token_chips.css @@ -85,7 +85,27 @@ vertical-align: baseline; } -.tokens-holder-display-block.tokens-holder-dense .salient-token span { +/** + * This is the ugliest matcher I've ever written but it seems to fix the mess + * that is element spacing in inline mode. In particular: with the way we use + * word-spacer, any token followed by one would get extra trailing whitespace, + * e.g. it would appear as |word | rather than |word|, making the highlighting + * awkward. + * + * It's possible to fix this by scrupulously removing whitespace from the HTML, + * so long as the are direct siblings of the word spacer. But this all + * breaks down when they're nested inside to provide the + * mouseover. + * + * So, instead we need to add a negative margin equal to the width + * of a single space, only to those .salient-token elements that are followed + * by a .word-spacer. Of course, CSS provides only a next-sibling combinator + * (+), which would work to match the .word-spacer itself - but applying + * margin-left there does not have the desired effect (you just get twice the + * spacing). So, we hack it with the unofficial but well-supported :has() + * pseudo-class to match .salient-token that "has" a next-sibling .word-spacer. + */ +.tokens-holder-display-block.tokens-holder-dense .salient-token:has(+ .word-spacer) span { /* hack to remove extra whitespace. ugh. */ margin-right: -0.445ch; } diff --git a/lit_nlp/client/elements/token_chips.ts b/lit_nlp/client/elements/token_chips.ts index 2e14ffdc..8f631345 100644 --- a/lit_nlp/client/elements/token_chips.ts +++ b/lit_nlp/client/elements/token_chips.ts @@ -97,12 +97,6 @@ export class TokenChips extends LitElement { let tokenText = tokenInfo.token; - let preSpace = false; - if (this.preSpace && tokenText.startsWith(' ')) { - preSpace = true; - tokenText = tokenText.slice(1); - } - // TODO(b/324955623): render a gray '⏎' for newlines? // Maybe make this a toggleable option, as it can be distracting. // TODO(b/324955623): better rendering for multiple newlines, like \n\n\n ? 
@@ -131,10 +125,21 @@ export class TokenChips extends LitElement { } } + let preSpace = false; + if (this.preSpace && tokenText.startsWith(' ')) { + preSpace = true; + tokenText = tokenText.slice(1); + } + + // Don't let token text shrink that much. + if (tokenText === '') { + tokenText = ' '; + } + // prettier-ignore return html` - ${preBreak ? html`
<br>` : null} ${preSpace ? html`<div class='word-spacer'> </div>` : null}
+        ${preBreak ? html`<br>` : null}
Date: Thu, 15 Feb 2024 19:07:26 -0800 Subject: [PATCH 28/51] Fix issue with URL params not updating when un-maximizing a module. Previously, this did not clear the expanded_module= bit, so sharing a link or refreshing the page would cause the module to maximize again. PiperOrigin-RevId: 607529159 --- lit_nlp/client/core/widget_group.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lit_nlp/client/core/widget_group.ts b/lit_nlp/client/core/widget_group.ts index 92802c6e..23210608 100644 --- a/lit_nlp/client/core/widget_group.ts +++ b/lit_nlp/client/core/widget_group.ts @@ -99,8 +99,7 @@ export class WidgetGroup extends ReactiveElement { // Maximization. const onMaxClick = () => { - this.maximized = !this.maximized; - this.setMaximized(this.maximized); + this.setMaximized(!this.maximized); this.setMinimized(false); }; @@ -114,7 +113,7 @@ export class WidgetGroup extends ReactiveElement { const onTitleClick = () => { if (this.minimized) { this.setMinimized(false); - this.maximized = false; + this.setMaximized(false); } }; @@ -220,7 +219,8 @@ export class WidgetGroup extends ReactiveElement { // For clicks on the maximized-module darkened background, undo the // module maximization. const onBackgroundClick = () => { - this.maximized = false; + this.setMaximized(false); + this.setMinimized(false); }; // A listener to stop clicks on a maximized module from causing the // background click listener from firing. From 77583e74236aa443a21ad0779b0ab9c023821b93 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Thu, 15 Feb 2024 22:09:12 -0800 Subject: [PATCH 29/51] Updates and UI fixes for LM salience - Remove redundant "Target:" label on dropdown - Help icon next to target selector dropdown - Fix tooltip text on colormap slider - Remove "Show self scores" toggle - Remove "token_loss" for now - Add a progress indicator for salience requests PiperOrigin-RevId: 607565021 --- lit_nlp/client/modules/lm_salience_module.css | 35 ++++++++--- lit_nlp/client/modules/lm_salience_module.ts | 61 +++++++++++++------ .../examples/models/instrumented_keras_lms.py | 8 +-- lit_nlp/examples/models/pretrained_lms.py | 6 +- 4 files changed, 77 insertions(+), 33 deletions(-) diff --git a/lit_nlp/client/modules/lm_salience_module.css b/lit_nlp/client/modules/lm_salience_module.css index 22192870..49a0c84b 100644 --- a/lit_nlp/client/modules/lm_salience_module.css +++ b/lit_nlp/client/modules/lm_salience_module.css @@ -7,10 +7,6 @@ padding: 8px; } -.chip-container-dense { - padding: 8px; -} - .pre-wrap { white-space: pre-wrap; } @@ -23,6 +19,7 @@ white-space: nowrap; text-overflow: ellipsis; overflow-x: hidden; + line-height: 22px; } lit-switch .icon-button { @@ -63,12 +60,10 @@ lit-switch .icon-button { margin-right: 8px; } -.controls-group-variable > label { - min-width: 45px; -} - .controls-group-variable .dropdown { - max-width: calc(100% - 45px); + max-width: calc(100% - 22px); + margin-right: 4px; + text-overflow: ellipsis; } .vertical-separator { @@ -95,4 +90,26 @@ color-legend { /* extra space to keep other controls from jumping when legend changes */ /* width: 400px; */ margin-right: 16px; +} + + +/* Pending request indicator */ +.loading-indicator-container { + position: relative; + width: 100%; + top: -2px; +} + +@keyframes running-progress { + 0% { margin-left: 0; margin-right: 100%; } + 50% { margin-left: 35%; margin-right: 0%; } + 100% { margin-left: 100%; margin-right: 0%; } +} + +.loading-indicator { + position: absolute; + background-color: var(--lit-neutral-500); + width: 100%; + 
height: 2px; + animation: running-progress 2s cubic-bezier(0.4, 0, 0.2, 1) infinite; } \ No newline at end of file diff --git a/lit_nlp/client/modules/lm_salience_module.ts b/lit_nlp/client/modules/lm_salience_module.ts index 1d064a70..14d69b9f 100644 --- a/lit_nlp/client/modules/lm_salience_module.ts +++ b/lit_nlp/client/modules/lm_salience_module.ts @@ -10,6 +10,7 @@ import '../elements/fused_button_bar'; import {css, html} from 'lit'; // tslint:disable:no-new-decorators import {customElement} from 'lit/decorators.js'; +import {classMap} from 'lit/directives/class-map.js'; import {computed, observable} from 'mobx'; import {LitModule} from '../core/lit_module'; @@ -556,17 +557,21 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { `; } + /* Disabled for space reasons. */ + // renderSelfScoreSelector() { + // const onClickToggleSelfSalience = () => { + // this.showSelfSalience = !this.showSelfSalience; + // }; + // // prettier-ignore + // return html` + // + // + // `; + // } renderSelfScoreSelector() { - const onClickToggleSelfSalience = () => { - this.showSelfSalience = !this.showSelfSalience; - }; - // prettier-ignore - return html` - - - `; + return null; } renderMethodSelector() { @@ -632,14 +637,29 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { `; }); + const targetSelectorHelp = + 'Select a (response) from the model or a pre-defined (target) sequence from the dataset.'; + // prettier-ignore return html`
- + + + help_outline + + +
`; + } + + renderLoadingIndicator() { + // prettier-ignore + return html` +
+        <div class='loading-indicator'></div>
+      </div>
-
+      <div class=${infoLineClasses}>
+        ${printSelectedTargets()}
+        ${requestPending ? this.renderLoadingIndicator() : null}
+      </div>
`; @@ -741,12 +771,9 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { }); } - // TODO: revert to 4px for non-dense view if we can figure out the - // display mode for token chips? Needs more padding for block mode, - // but also indentation and newlines are wonky. // prettier-ignore return html` -
+
@@ -793,7 +820,7 @@ export class LMSalienceModule extends SingleExampleSingleModelModule { - restart_alt diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py index 96a0af72..453487bd 100644 --- a/lit_nlp/examples/models/instrumented_keras_lms.py +++ b/lit_nlp/examples/models/instrumented_keras_lms.py @@ -307,7 +307,7 @@ def layer_intercept_fn(x, i): FieldNames.GRAD_NORM: grad_l2, FieldNames.GRAD_DOT_INPUT: grad_dot_input, # Shift token loss to align with (input) tokens. - FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1), + # FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1), } return batched_outputs @@ -322,7 +322,7 @@ def _postprocess(self, preds): ): preds[key] = preds[key][mask] # First token () is not actually predicted, so return 0 for loss. - preds[FieldNames.TOKEN_LOSS][0] = 0 + # preds[FieldNames.TOKEN_LOSS][0] = 0 return preds @@ -353,11 +353,11 @@ def input_spec(self): def output_spec(self) -> lit_types.Spec: return { FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens. + FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS), FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores( align=FieldNames.TOKENS ), - FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS), - FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS), + # FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS), } diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 9fdae45d..cbb9e89c 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -594,7 +594,7 @@ def _pred(self, encoded_inputs, target_masks): "grad_l2": grad_l2, "grad_dot_input": grad_dot_input, # Shift token loss to align with (input) tokens. - "token_loss": tf.roll(per_token_loss, shift=1, axis=1), + # "token_loss": tf.roll(per_token_loss, shift=1, axis=1), } return batched_outputs @@ -609,7 +609,7 @@ def _postprocess(self, preds): for key in utils.find_spec_keys(self.output_spec(), lit_types.TokenScores): preds[key] = preds[key][mask] # First token (usually ) is not actually predicted, so return 0 for loss. - preds["token_loss"][0] = 0 + # preds["token_loss"][0] = 0 return preds @@ -645,7 +645,7 @@ def output_spec(self) -> lit_types.Spec: "tokens": lit_types.Tokens(parent=""), # all tokens "grad_l2": lit_types.TokenScores(align="tokens"), "grad_dot_input": lit_types.TokenScores(align="tokens"), - "token_loss": lit_types.TokenScores(align="tokens"), + # "token_loss": lit_types.TokenScores(align="tokens"), } From cdf79eb9048be3e6798e916d5e1ac4cc294929b0 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Fri, 16 Feb 2024 11:55:45 -0800 Subject: [PATCH 30/51] Notebook widget improvements - Possible to now pass examples directly to widget.render for ad-hoc analysis. - Fix an issue with UI state syncing back to Python on generated/newly-added examples. 
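For example, a minimal sketch of the new workflow (assuming a LitWidget has
already been constructed from models and datasets as usual; the prompt values
here are illustrative):

    from lit_nlp import notebook

    # Ad-hoc examples can be passed straight to render(); each dict is
    # flattened into data{i}_{field} URL parameters via RenderConfig.
    widget.render(data=[{"prompt": "Hello world"}])

    # The same encoding can be inspected directly:
    config = notebook.RenderConfig(datapoints=[{"prompt": "Hello world"}])
    print(config.get_query_str())  # ?data0_prompt=Hello+world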
PiperOrigin-RevId: 607757916 --- lit_nlp/lib/ui_state.py | 18 ++++++++++++++++-- lit_nlp/notebook.py | 34 ++++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/lit_nlp/lib/ui_state.py b/lit_nlp/lib/ui_state.py index 6ce8707e..72adfa8b 100644 --- a/lit_nlp/lib/ui_state.py +++ b/lit_nlp/lib/ui_state.py @@ -19,6 +19,7 @@ """ from typing import Optional +from absl import logging import attr from lit_nlp.api import dataset as lit_dataset from lit_nlp.api import types @@ -64,14 +65,27 @@ def update_state(self, self._state.dataset_name = dataset_name self._state.dataset = dataset + # This may contain 'added' datapoints not in the base dataset. + input_index = {ex["data"]["_id"]: ex for ex in indexed_inputs} + + def get_example(example_id): + ex = input_index.get(example_id) + if ex is None: + ex = dataset.index.get(example_id) + return ex + if primary_id: - self._state.primary = dataset.index[primary_id] + self._state.primary = get_example(primary_id) + if self._state.primary is None: + logging.warn("State tracker: unable to find primary_id %s", primary_id) else: self._state.primary = None self._state.selection = indexed_inputs if pinned_id: - self._state.pinned = dataset.index[pinned_id] + self._state.pinned = get_example(pinned_id) + if self._state.pinned is None: + logging.warn("State tracker: unable to find pinned_id %s", pinned_id) else: self._state.pinned = None diff --git a/lit_nlp/notebook.py b/lit_nlp/notebook.py index 87a051cd..51f37a6c 100644 --- a/lit_nlp/notebook.py +++ b/lit_nlp/notebook.py @@ -7,14 +7,15 @@ through the render() method. Use the stop() method to stop the server when done. """ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import html import json import os import pathlib import random -from typing import cast, Optional +from typing import Any, Optional, cast import urllib.parse + import attr from IPython import display from lit_nlp import dev_server @@ -22,6 +23,8 @@ from lit_nlp.api import layout from lit_nlp.lib import wsgi_serving +JsonDict = Mapping[str, Any] + is_colab = False try: import google.colab # pylint: disable=g-import-not-at-top,unused-import @@ -66,6 +69,7 @@ class RenderConfig(object): layout: Optional[str] = None dataset: Optional[str] = None models: Optional[Sequence[str]] = None + datapoints: Optional[Sequence[JsonDict]] = None def get_query_str(self): """Convert config object to query string for LIT URL.""" @@ -75,8 +79,15 @@ def _encode(v): return v string_params = { - k: _encode(v) for k, v in attr.asdict(self).items() if v is not None + k: _encode(v) + for k, v in attr.asdict(self).items() + if (v is not None and k != 'datapoints') } + if self.datapoints: + for i, ex in enumerate(self.datapoints): + for field in ex: + string_params[f'data{i}_{field}'] = _encode(ex[field]) + return '?' + urllib.parse.urlencode(string_params) @@ -134,21 +145,32 @@ def stop(self): """Stop the LIT server.""" self._server.stop() - def render(self, height=None, open_in_new_tab=False, - ui_params: Optional[RenderConfig] = None): + def render( + self, + height=None, + open_in_new_tab=False, + ui_params: Optional[RenderConfig] = None, + data: Optional[Sequence[JsonDict]] = None, + ): """Render the LIT UI in the output cell. + To immediately analyze specifiic example(s), use the data= parameter: + widget.render(..., data=[{"prompt": "Hello world "}]) + Args: height: Optional height to display the LIT UI in pixels. If not specified, - then the height specified in the constructor is used. 
+ then the height specified in the constructor is used. open_in_new_tab: Whether to show the UI in a new tab instead of in the output cell. Defaults to false. ui_params: Optional configuration options for the LIT UI's state. + data: Optional examples to load directly to the UI (via URL params). """ if not height: height = self._height if not ui_params: ui_params = RenderConfig() + if data: + ui_params.datapoints = data if is_colab: _display_colab(self._server.port, height, open_in_new_tab, ui_params) else: From 4f1989180ee570642285682f843242be5bffb9ef Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Fri, 16 Feb 2024 12:26:35 -0800 Subject: [PATCH 31/51] Add more prompt examples. PiperOrigin-RevId: 607767227 --- lit_nlp/examples/datasets/prompt_examples.jsonl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lit_nlp/examples/datasets/prompt_examples.jsonl b/lit_nlp/examples/datasets/prompt_examples.jsonl index c5dd7cda..3ee84450 100644 --- a/lit_nlp/examples/datasets/prompt_examples.jsonl +++ b/lit_nlp/examples/datasets/prompt_examples.jsonl @@ -2,8 +2,11 @@ {"source": "gigaword-summarization", "prompt": "Summarize the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nSummary:\n", "target": "* Transocean and Globalsantafe merge to form a new, larger company.\n* The combined company will offer a full range of offshore drilling services.\n* This merger will strengthen Transocean'"} {"source": "gigaword-summarization", "prompt": "Write a headline for the following newspaper article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "Transocean and Globalsantafe merge: New giant in offshore drilling"} {"source": "gigaword-summarization", "prompt": "You are an editor at the New York Times. Write a headline for the following article.\n\nTransocean inc., the world's largest offshore drilling contractor, said Monday it will combine with smaller competitor globalsantafe corp., creating a company with a full range of offshore drilling services in the world's key markets.\n\nHeadline:\n", "target": "**Transocean and Globalsantafe Merge in a Giant Move for Offshore Drilling**"} -{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?", "target": "\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? 
** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} {"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? **", "target": " Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?", "target": "\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? 
** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} +{"source": "fewshot-mistake", "prompt": "Analyze a menu item in a restaurant.\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: You have to try it.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} +{"source": "fewshot-fixed", "prompt": "Analyze a menu item in a restaurant.\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: Avoid.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} +{"source": "fewshot-constitution", "prompt": "Analyze a menu item in a restaurant.\n\n* The analysis should be brief and to the point.\n* The analysis and recommendation should both be clear about the suitability for someone with a specified dietary restriction.\n* It should reflect the persons tastes\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: Avoid.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. Cutting fabric for crafts.\n2. Cutting herbs and vegetables in the kitchen.\n3. Cutting paper for DIY projects.\n\nPlease provide detailed instructions for using any two of the ideas."} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item. The list should include creative uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic embellishments:** Use scissors to create intricate designs and patterns on fabric, paper, or other materials.\n2. 
**Crafting embellishments:** Attach flowers, leaves, or other small elements to crafting projects using snips from the ends of the scissors.\n3. **Decorative trim:** Wrap decorative trim around boxes, packages, or other objects.\n\nBonus Idea:\n\n4. **Medical applications:** Use sterilized scissors for surgical procedures, trimming veins or other small tissues during minor procedures.\n\nExplain your reasoning and provide examples for each idea."} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item.The list should include creative uses for the item. The list should include at least three specific uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic tool:** Use scissors to create intricate patterns and designs on fabric, paper, or wood.\n2. **Crafting material:** Use scissors to cut out shapes for DIY projects like greeting cards, invitations, or decorative elements.\n3. **Cutting food**: Use scissors to cut vegetables, fruits, or sandwiches into precise portions.\n\n**Please provide the three specific uses for the scissors. The more specific and unique, the better.**"} From 000c84486ed61439c98dbfdd92959bdbb6f5119f Mon Sep 17 00:00:00 2001 From: Bin Du Date: Fri, 16 Feb 2024 14:03:13 -0800 Subject: [PATCH 32/51] Update the fewshot constitution prompt. PiperOrigin-RevId: 607794292 --- lit_nlp/examples/datasets/prompt_examples.jsonl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/examples/datasets/prompt_examples.jsonl b/lit_nlp/examples/datasets/prompt_examples.jsonl index 3ee84450..f9ee6a04 100644 --- a/lit_nlp/examples/datasets/prompt_examples.jsonl +++ b/lit_nlp/examples/datasets/prompt_examples.jsonl @@ -6,7 +6,7 @@ {"source": "gsm8k", "prompt": "A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?", "target": "\nHow much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.\nHow much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.\nHow much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.\nHow much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.\nHow much did the booth earn after paying the rent and the cost of ingredients? 
** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895."} {"source": "fewshot-mistake", "prompt": "Analyze a menu item in a restaurant.\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: You have to try it.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} {"source": "fewshot-fixed", "prompt": "Analyze a menu item in a restaurant.\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: Avoid.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} -{"source": "fewshot-constitution", "prompt": "Analyze a menu item in a restaurant.\n\n* The analysis should be brief and to the point.\n* The analysis and recommendation should both be clear about the suitability for someone with a specified dietary restriction.\n* It should reflect the persons tastes\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: Avoid.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} +{"source": "fewshot-constitution", "prompt": "Analyze a menu item in a restaurant.\n\n* The analysis should be brief and to the point.\n* It should provide a clear statement of suitability for someone with specified dietary restriction.\n* It should reflect the persons tastes\n\n## For example:\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Onion soup\nAnalysis: it has cooked onions in it, which you don't like.\nRecommendation: Avoid.\n\nTaste-likes: I've a sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Baguette maison au levain\nAnalysis: Home-made leaven bread in france is usually great\nRecommendation: Likely good.\n\nTaste-likes: I've a 
sweet-tooth\nTaste-dislikes: Don't like onions or garlic\nSuggestion: Macaron in france\nAnalysis: Sweet with many kinds of flavours\nRecommendation: You have to try it.\n\n## Now analyze one more example:\n\nTaste-likes: Cheese\nTaste-dislikes: Can't eat eggs\nSuggestion: Quiche Lorraine\nAnalysis:\n", "target": ""} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. Cutting fabric for crafts.\n2. Cutting herbs and vegetables in the kitchen.\n3. Cutting paper for DIY projects.\n\nPlease provide detailed instructions for using any two of the ideas."} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item. The list should include creative uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic embellishments:** Use scissors to create intricate designs and patterns on fabric, paper, or other materials.\n2. **Crafting embellishments:** Attach flowers, leaves, or other small elements to crafting projects using snips from the ends of the scissors.\n3. **Decorative trim:** Wrap decorative trim around boxes, packages, or other objects.\n\nBonus Idea:\n\n4. **Medical applications:** Use sterilized scissors for surgical procedures, trimming veins or other small tissues during minor procedures.\n\nExplain your reasoning and provide examples for each idea."} {"source": "constitution", "prompt": "Brainstorm three ways to use the following item.The list should include creative uses for the item. The list should include at least three specific uses for the item.\n\nItem: scissors\n\nIdeas:", "target": "\n\n1. **Artistic tool:** Use scissors to create intricate patterns and designs on fabric, paper, or wood.\n2. **Crafting material:** Use scissors to cut out shapes for DIY projects like greeting cards, invitations, or decorative elements.\n3. **Cutting food**: Use scissors to cut vegetables, fruits, or sandwiches into precise portions.\n\n**Please provide the three specific uses for the scissors. The more specific and unique, the better.**"} From a758f98c5153f23955b0190a75dc1258ba57b645 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Fri, 16 Feb 2024 15:17:12 -0800 Subject: [PATCH 33/51] Fix textarea sizing in datapoint editor. Textareas (.entry-long) now flex-grow to fill available space. The previous content-based height is now applied to min-height, so that they don't become too small before scrolling is triggered. This should greatly improve the experience of using LIT with longer text. Also fix a minor issue with text wrapping in footer buttons. 
PiperOrigin-RevId: 607815280 --- .../modules/datapoint_editor_module.css | 33 ++++++++++++++++--- .../client/modules/datapoint_editor_module.ts | 31 ++++++++--------- lit_nlp/client/modules/embeddings_module.css | 4 +++ 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/lit_nlp/client/modules/datapoint_editor_module.css b/lit_nlp/client/modules/datapoint_editor_module.css index 00a26292..78b16d67 100644 --- a/lit_nlp/client/modules/datapoint_editor_module.css +++ b/lit_nlp/client/modules/datapoint_editor_module.css @@ -133,6 +133,12 @@ select.dropdown { min-width: 50px; } +#edit-table { + display: flex; + flex-direction: column; + height: 100%; +} + .entry { align-items: center; box-sizing: border-box; @@ -144,6 +150,15 @@ select.dropdown { border-radius: 4px; } +.entry-long { + flex-grow: 1; + flex-basis: 100%; + flex-direction: column; + align-items: flex-start; + justify-content: flex-start; + flex-flow: column; +} + .entry-edited { background: var(--lit-bric-50); } @@ -154,12 +169,23 @@ select.dropdown { justify-content: flex-end; } -.entry-content-long { +.entry-medium .entry-content { flex-grow: 1; - flex-basis: 100%; + max-width: max(50%, 500px); +} + +.entry-long .entry-content { + flex-direction: column; + justify-content: flex-start; + height: 100%; + width: 100%; +} + +.entry-long .entry-content textarea { + height: 100%; } -.entry-content.left-align { +.entry-left-align .entry-content { justify-content: flex-start; } @@ -170,7 +196,6 @@ input { textarea { resize: vertical; - height: 100px; font-family: 'Roboto', sans; line-height: 18px; color: var(--lit-gray-800); diff --git a/lit_nlp/client/modules/datapoint_editor_module.ts b/lit_nlp/client/modules/datapoint_editor_module.ts index a27460e3..14dc7514 100644 --- a/lit_nlp/client/modules/datapoint_editor_module.ts +++ b/lit_nlp/client/modules/datapoint_editor_module.ts @@ -382,7 +382,7 @@ export class DatapointEditorModule extends LitModule { `; const compareButton = html` - @@ -524,7 +524,7 @@ export class DatapointEditorModule extends LitModule { const errorInputClasses = renderError ? 'error-input' : ''; const errorIconClasses = renderError ? 'error-icon' : ''; - const inputStyle = {'height': this.inputHeights[key]}; + const inputStyle = {'min-height': this.inputHeights[key]}; // Render a multi-line text input. const renderFreeformInput = () => html`